Skip to content

Commit ab84d01

Browse files
committed
simplify getaddrinfo
1 parent 98a6162 commit ab84d01

File tree

1 file changed

+42
-62
lines changed

1 file changed

+42
-62
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/SocketModuleBuiltins.java

Lines changed: 42 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
import com.oracle.truffle.api.dsl.NodeFactory;
8080
import com.oracle.truffle.api.dsl.Specialization;
8181
import com.oracle.truffle.api.frame.VirtualFrame;
82+
import com.oracle.truffle.api.profiles.BranchProfile;
8283

8384
import org.graalvm.nativeimage.ImageInfo;
8485

@@ -453,88 +454,65 @@ Object getNameInfo(PTuple sockaddr, int flags) {
453454
@Builtin(name = "getaddrinfo", parameterNames = {"host", "port", "family", "type", "proto", "flags"})
454455
@GenerateNodeFactory
455456
public abstract static class GetAddrInfoNode extends PythonBuiltinNode {
456-
@Specialization
457-
Object getAddrInfo(String host, int port, int family, int type, int proto, PInt flags) {
458-
return getAddrInfo(host, port, family, type, proto, flags.asInt());
459-
}
460-
461-
@Specialization
462-
Object getAddrInfo(String host, Integer port, PInt family, PInt type, Integer proto, Integer flags) {
463-
return getAddrInfo(host, port, family.asInt(), type.asInt(), proto, flags);
464-
}
457+
BranchProfile stringPortProfile = BranchProfile.create();
458+
BranchProfile nonePortProfile = BranchProfile.create();
459+
BranchProfile intPortProfile = BranchProfile.create();
465460

466461
@Specialization
467-
Object getAddrInfo(String host, String port, int family, int type, int proto, PInt flags) {
468-
return getAddrInfo(host, port, family, type, proto, flags.asInt());
462+
Object getAddrInfoPString(PString host, Object port, Object family, Object type, Object proto, Object flags,
463+
@Cached CastToIndexNode cast) {
464+
return getAddrInfoString(host.getValue(), port, family, type, proto, flags, cast);
469465
}
470466

471467
@Specialization
472-
@TruffleBoundary
473-
Object getAddrInfo(@SuppressWarnings("unused") PNone host, Integer port, PInt family, PInt type, Integer proto, PInt flags) {
474-
return getAddrInfo(host, port, family.intValue(), type.intValue(), proto, flags.intValue());
475-
}
476-
477-
@Specialization
478-
@TruffleBoundary
479-
Object getAddrInfo(@SuppressWarnings("unused") PNone host, String port, int family, int type, int proto, int flags) {
480-
return getAddrInfo("localhost", port, family, type, proto, flags);
468+
Object getAddrInfoNone(@SuppressWarnings("unused") PNone host, Object port, Object family, Object type, Object proto, Object flags,
469+
@Cached CastToIndexNode cast) {
470+
return getAddrInfoString("localhost", port, family, type, proto, flags, cast);
481471
}
482472

483473
@Specialization
484-
@TruffleBoundary
485-
Object getAddrInfo(@SuppressWarnings("unused") PNone host, int port, int family, int type, int proto, int flags) {
486-
return getAddrInfo("localhost", port, family, type, proto, flags);
487-
}
488-
489-
@Specialization
490-
@TruffleBoundary
491-
Object getAddrInfo(String host, @SuppressWarnings("unused") PNone port, PInt family, Integer type, Integer proto, Integer flags) {
492-
return getAddrInfo(host, port, family.intValue(), type, proto, flags);
493-
}
494-
495-
@Specialization
496-
@TruffleBoundary
497-
Object getAddrInfo(String host, @SuppressWarnings("unused") PNone port, Integer family, Integer type, Integer proto, PInt flags) {
498-
return getAddrInfo(host, port, family, type, proto, flags.intValue());
499-
}
474+
Object getAddrInfoString(String host, Object port, Object family, Object type, Object proto, Object flags,
475+
@Cached CastToIndexNode cast) {
476+
String stringPort = null;
477+
if (port instanceof PString) {
478+
stringPort = ((PString) port).getValue();
479+
} else if (port instanceof String) {
480+
stringPort = (String) port;
481+
}
500482

501-
@Specialization
502-
@TruffleBoundary
503-
Object getAddrInfo(String host, @SuppressWarnings("unused") PNone port, Integer family, PInt type, Integer proto, Integer flags) {
504-
return getAddrInfo(host, port, family, type.intValue(), proto, flags);
505-
}
483+
if (stringPort != null) {
484+
stringPortProfile.enter();
485+
return getAddrInfo(host, stringPort, cast.execute(family), cast.execute(type), cast.execute(proto), cast.execute(flags));
486+
}
506487

507-
@Specialization
508-
@TruffleBoundary
509-
Object getAddrInfo(String host, @SuppressWarnings("unused") PNone port, int family, int type, int proto, int flags) {
510-
List<Service> serviceList = new LinkedList<>();
488+
if (port instanceof PNone) {
489+
nonePortProfile.enter();
490+
InetAddress[] adresses = resolveHost(host);
491+
return mergeAdressesAndServices(adresses, null, cast.execute(family), cast.execute(type), cast.execute(proto), cast.execute(flags));
492+
}
511493

512-
InetAddress[] adresses = resolveHost(host);
513-
return mergeAdressesAndServices(adresses, serviceList, family, type, proto, flags);
494+
intPortProfile.enter();
495+
return getAddrInfo(host, cast.execute(port), cast.execute(family), cast.execute(type), cast.execute(proto), cast.execute(flags));
514496
}
515497

516-
@Specialization
517498
@TruffleBoundary
518-
Object getAddrInfo(String host, int port, int family, int type, int proto, int flags) {
499+
private Object getAddrInfo(String host, int port, int family, int type, int proto, int flags) {
519500
InetAddress[] adresses = resolveHost(host);
520501
List<Service> serviceList = new ArrayList<>();
521502
serviceList.add(new Service(port, "tcp"));
522503
serviceList.add(new Service(port, "udp"));
523504
return mergeAdressesAndServices(adresses, serviceList, family, type, proto, flags);
524505
}
525506

526-
@Specialization
527507
@TruffleBoundary
528-
Object getAddrInfo(String host, String port, int family, int type, int proto, int flags) {
508+
private Object getAddrInfo(String host, String port, int family, int type, int proto, int flags) {
529509
if (!StandardCharsets.US_ASCII.newEncoder().canEncode(port)) {
530510
throw raise(PythonBuiltinClassType.UnicodeEncodeError);
531511
}
532-
533512
if (services == null) {
534513
services = parseServices();
535514
}
536515
List<Service> serviceList = services.get(port);
537-
538516
InetAddress[] adresses = resolveHost(host);
539517
return mergeAdressesAndServices(adresses, serviceList, family, type, proto, flags);
540518
}
@@ -545,16 +523,17 @@ private Object mergeAdressesAndServices(InetAddress[] adresses, List<Service> se
545523
protocols = parseProtocols();
546524
}
547525
List<Object> addressTuples = new ArrayList<>();
548-
549526
for (InetAddress addr : adresses) {
550-
for (Service srv : serviceList) {
551-
int protocol = protocols.get(srv.protocol);
552-
if (proto != 0 && proto != protocol) {
553-
continue;
554-
}
555-
PTuple addrTuple = createAddressTuple(addr, srv.port, family, type, protocol, flags);
556-
if (addrTuple != null) {
557-
addressTuples.add(addrTuple);
527+
if (serviceList != null) {
528+
for (Service srv : serviceList) {
529+
int protocol = protocols.get(srv.protocol);
530+
if (proto != 0 && proto != protocol) {
531+
continue;
532+
}
533+
PTuple addrTuple = createAddressTuple(addr, srv.port, family, type, protocol, flags);
534+
if (addrTuple != null) {
535+
addressTuples.add(addrTuple);
536+
}
558537
}
559538
}
560539
}
@@ -585,6 +564,7 @@ private PTuple createAddressTuple(InetAddress address, int port, int family, int
585564
return factory().createTuple(new Object[]{addressFamily, addressType, proto, canonname, sockAddr});
586565
}
587566

567+
@TruffleBoundary
588568
InetAddress[] resolveHost(String host) {
589569
try {
590570
return InetAddress.getAllByName(host);

0 commit comments

Comments
 (0)