Skip to content

Commit 6f0f71d

Browse files
committed
Fix: implement generic argument parsing for _locale.setlocale.
1 parent 7936cbf commit 6f0f71d

File tree

1 file changed

+41
-9
lines changed

1 file changed

+41
-9
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/LocaleModuleBuiltins.java

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,22 @@
5555
import com.oracle.graal.python.builtins.PythonBuiltins;
5656
import com.oracle.graal.python.builtins.objects.PNone;
5757
import com.oracle.graal.python.builtins.objects.dict.PDict;
58+
import com.oracle.graal.python.builtins.objects.function.PArguments;
59+
import com.oracle.graal.python.builtins.objects.object.PythonObjectLibrary;
5860
import com.oracle.graal.python.nodes.ErrorMessages;
61+
import com.oracle.graal.python.nodes.PGuards;
5962
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
63+
import com.oracle.graal.python.nodes.util.CannotCastException;
64+
import com.oracle.graal.python.nodes.util.CastToJavaStringNode;
6065
import com.oracle.graal.python.runtime.exception.PythonErrorType;
6166
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
6267
import com.oracle.truffle.api.TruffleOptions;
63-
import com.oracle.truffle.api.dsl.Fallback;
68+
import com.oracle.truffle.api.dsl.Cached;
6469
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
6570
import com.oracle.truffle.api.dsl.NodeFactory;
6671
import com.oracle.truffle.api.dsl.Specialization;
72+
import com.oracle.truffle.api.frame.VirtualFrame;
73+
import com.oracle.truffle.api.library.CachedLibrary;
6774

6875
@CoreFunctions(defineModule = "_locale")
6976
public class LocaleModuleBuiltins extends PythonBuiltins {
@@ -222,9 +229,9 @@ public PDict localeconv() {
222229
public abstract static class SetLocaleNode extends PythonBuiltinNode {
223230

224231
@SuppressWarnings("fallthrough")
225-
@Specialization(guards = {"category >= 0", "category <= 6"})
232+
@Specialization(guards = "isValidCategory(category)")
226233
@TruffleBoundary
227-
public Object setLocale(int category, @SuppressWarnings("unused") PNone posixLocaleID) {
234+
Object doWithoutLocaleID(int category, @SuppressWarnings("unused") PNone posixLocaleID) {
228235
Locale defaultLocale;
229236
Locale.Category displayCategory = null;
230237
Locale.Category formatCategory = null;
@@ -258,10 +265,10 @@ public Object setLocale(int category, @SuppressWarnings("unused") PNone posixLoc
258265
return toPosix(defaultLocale);
259266
}
260267

261-
@SuppressWarnings("fallthrough")
262-
@Specialization(guards = {"category >= 0", "category <= 6"})
268+
@Specialization(guards = "isValidCategory(category)")
263269
@TruffleBoundary
264-
public Object setLocale(int category, String posixLocaleID) {
270+
@SuppressWarnings("fallthrough")
271+
Object doWithLocaleID(int category, String posixLocaleID) {
265272
Locale.Category displayCategory = null;
266273
Locale.Category formatCategory = null;
267274
if (!TruffleOptions.AOT) {
@@ -303,9 +310,34 @@ public Object setLocale(int category, String posixLocaleID) {
303310
return toPosix(newLocale);
304311
}
305312

306-
@Fallback
307-
public Object setLocale(@SuppressWarnings("unused") Object category, @SuppressWarnings("unused") Object locale) {
308-
throw raise(PythonErrorType.ValueError, ErrorMessages.INVALID_LOCALE_CATEGORY);
313+
@Specialization(replaces = {"doWithoutLocaleID", "doWithLocaleID"}, limit = "3")
314+
Object doGeneric(VirtualFrame frame, Object category, Object posixLocaleID,
315+
@CachedLibrary("category") PythonObjectLibrary categoryLib,
316+
@Cached CastToJavaStringNode castToJavaStringNode) {
317+
318+
long l = categoryLib.asJavaLongWithState(category, PArguments.getThreadState(frame));
319+
if (!isValidCategory(l)) {
320+
throw raise(PythonErrorType.ValueError, ErrorMessages.INVALID_LOCALE_CATEGORY);
321+
}
322+
323+
String posixLocaleIDStr = null;
324+
// may be NONE or NO_VALUE
325+
if (!PGuards.isPNone(posixLocaleID)) {
326+
try {
327+
posixLocaleIDStr = castToJavaStringNode.execute(posixLocaleID);
328+
} catch (CannotCastException e) {
329+
// fall through
330+
}
331+
}
332+
333+
if (posixLocaleIDStr != null) {
334+
return doWithLocaleID((int) l, posixLocaleIDStr);
335+
}
336+
return doWithoutLocaleID((int) l, PNone.NONE);
337+
}
338+
339+
static boolean isValidCategory(long l) {
340+
return 0 <= l && l <= 6;
309341
}
310342
}
311343
}

0 commit comments

Comments
 (0)