Skip to content

Commit ded8f53

Browse files
committed
transform recursion into loop for dotted imports
1 parent f176517 commit ded8f53

File tree

1 file changed

+46
-27
lines changed

1 file changed

+46
-27
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/statement/AbstractImportNode.java

Lines changed: 46 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
import com.oracle.truffle.api.CompilerDirectives;
7575
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
7676
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
77+
import com.oracle.truffle.api.CompilerDirectives.ValueType;
7778
import com.oracle.truffle.api.TruffleLanguage.ContextReference;
7879
import com.oracle.truffle.api.TruffleLanguage.LanguageReference;
7980
import com.oracle.truffle.api.dsl.Bind;
@@ -86,6 +87,7 @@
8687
import com.oracle.truffle.api.library.CachedLibrary;
8788
import com.oracle.truffle.api.nodes.Node;
8889
import com.oracle.truffle.api.object.DynamicObjectLibrary;
90+
import com.oracle.truffle.api.profiles.ConditionProfile;
8991

9092
public abstract class AbstractImportNode extends StatementNode {
9193
@CompilationFinal private LanguageReference<PythonLanguage> languageRef;
@@ -232,6 +234,15 @@ static Object levelZeroNoFromlist(VirtualFrame frame, PythonContext context, Str
232234
return mod;
233235
}
234236

237+
@ValueType
238+
private static final class ModuleFront {
239+
private String front;
240+
241+
private ModuleFront(String front) {
242+
this.front = front;
243+
}
244+
}
245+
235246
@Specialization(guards = "level >= 0", replaces = "levelZeroNoFromlist")
236247
static Object genericImport(VirtualFrame frame, PythonContext context, String name, Object globals, String[] fromList, int level,
237248
@Cached ResolveName resolveName,
@@ -241,7 +252,8 @@ static Object genericImport(VirtualFrame frame, PythonContext context, String na
241252
@Cached PyObjectLookupAttr getPathNode,
242253
@Cached PyObjectCallMethodObjArgs callHandleFromlist,
243254
@Cached PythonObjectFactory factory,
244-
@Cached FindAndLoad findAndLoad) {
255+
@Cached FindAndLoad findAndLoad,
256+
@Cached ConditionProfile recursiveCase) {
245257
String absName;
246258
if (level > 0) {
247259
absName = resolveName.execute(frame, name, globals, level);
@@ -266,17 +278,18 @@ static Object genericImport(VirtualFrame frame, PythonContext context, String na
266278
return mod;
267279
}
268280
if (level == 0) {
269-
String front = name.substring(0, dotIndex);
270-
// recursion up number of dots in the name. we use a boundary call here to avoid PE going recursive
271-
return genericImportBoundary(frame.materialize(), context, front, null, PythonUtils.EMPTY_STRING_ARRAY, 0,
272-
null, // resolveName is not used with level == 0
273-
raiseNode, // raiseNode only needed if front.length() == 0 at this point
274-
getModuleNode, // used multiple times to get the 'front' module
275-
ensureInitialized, // used multiple times on the 'front' module
276-
null, // getPathNode is not used with fromList.length == 0
277-
null, // callHandleFromlist not used with fromList.length == 0
278-
null, // factory not used with fromList.length == 0
279-
findAndLoad); // used multiple times, but always to call the exact same function
281+
Object front = new ModuleFront(name.substring(0, dotIndex));
282+
// cpython recurses, we have transformed the recursion into a loop
283+
do {
284+
// we omit a few arguments in the recursion, because that makes things simpler.
285+
// globals are null, fromlist is empty, level is 0
286+
front = genericImportRecursion(frame, context, (ModuleFront)front,
287+
raiseNode, // raiseNode only needed if front.length() == 0 at this point
288+
getModuleNode, // used multiple times to get the 'front' module
289+
ensureInitialized, // used multiple times on the 'front' module
290+
findAndLoad); // used multiple times, but always to call the exact same function
291+
} while (recursiveCase.profile(front instanceof ModuleFront));
292+
return front;
280293
} else {
281294
int cutoff = nameLength - dotIndex;
282295
String toReturn = absName.substring(0, absName.length() - cutoff);
@@ -302,25 +315,31 @@ static Object genericImport(VirtualFrame frame, PythonContext context, String na
302315
}
303316
}
304317

305-
@TruffleBoundary
306-
static Object genericImportBoundary(MaterializedFrame frame, PythonContext context, String name, Object globals, String[] fromList, int level,
307-
ResolveName resolveName,
318+
static final Object genericImportRecursion(VirtualFrame frame, PythonContext context, ModuleFront front,
308319
PRaiseNode raiseNode,
309320
PyDictGetItem getModuleNode,
310321
EnsureInitializedNode ensureInitialized,
311-
PyObjectLookupAttr getPathNode,
312-
PyObjectCallMethodObjArgs callHandleFromlist,
313-
PythonObjectFactory factory,
314322
FindAndLoad findAndLoad) {
315-
return genericImport(frame, context, name, globals, fromList, level,
316-
resolveName,
317-
raiseNode,
318-
getModuleNode,
319-
ensureInitialized,
320-
getPathNode,
321-
callHandleFromlist,
322-
factory,
323-
findAndLoad);
323+
String absName = front.front;
324+
if (absName.length() == 0) {
325+
throw raiseNode.raise(PythonBuiltinClassType.ValueError, "Empty module name");
326+
}
327+
PDict sysModules = context.getSysModules();
328+
Object mod = getModuleNode.execute(frame, sysModules, absName); // import_get_module
329+
if (mod != null && mod != PNone.NONE) {
330+
ensureInitialized.execute(frame, context, mod, absName);
331+
} else {
332+
mod = findAndLoad.execute(frame, context, absName);
333+
}
334+
// fromList.length == 0
335+
// level == 0
336+
int dotIndex = absName.indexOf('.');
337+
if (dotIndex == -1) {
338+
return mod;
339+
}
340+
// level == 0
341+
front.front = absName.substring(0, dotIndex);
342+
return front;
324343
}
325344
}
326345

0 commit comments

Comments
 (0)