|
40 | 40 | */
|
41 | 41 | package com.oracle.graal.python.builtins.modules;
|
42 | 42 |
|
43 |
| -import java.util.ArrayList; |
| 43 | +import static com.oracle.graal.python.builtins.objects.partial.PartialBuiltins.getNewPartialArgs; |
| 44 | +import static com.oracle.graal.python.nodes.BuiltinNames.PARTIAL; |
| 45 | +import static com.oracle.graal.python.nodes.ErrorMessages.REDUCE_EMPTY_SEQ; |
| 46 | +import static com.oracle.graal.python.nodes.ErrorMessages.S_ARG_MUST_BE_CALLABLE; |
| 47 | +import static com.oracle.graal.python.nodes.ErrorMessages.S_ARG_N_MUST_SUPPORT_ITERATION; |
| 48 | +import static com.oracle.graal.python.nodes.ErrorMessages.TYPE_S_TAKES_AT_LEAST_ONE_ARGUMENT; |
| 49 | +import static com.oracle.truffle.api.nodes.LoopNode.reportLoopCount; |
| 50 | + |
44 | 51 | import java.util.List;
|
45 | 52 |
|
| 53 | +import com.oracle.graal.python.builtins.Builtin; |
46 | 54 | import com.oracle.graal.python.builtins.CoreFunctions;
|
| 55 | +import com.oracle.graal.python.builtins.PythonBuiltinClassType; |
47 | 56 | import com.oracle.graal.python.builtins.PythonBuiltins;
|
| 57 | +import com.oracle.graal.python.builtins.objects.PNone; |
| 58 | +import com.oracle.graal.python.builtins.objects.common.HashingStorage; |
| 59 | +import com.oracle.graal.python.builtins.objects.common.HashingStorageLibrary; |
| 60 | +import com.oracle.graal.python.builtins.objects.dict.PDict; |
| 61 | +import com.oracle.graal.python.builtins.objects.function.PKeyword; |
| 62 | +import com.oracle.graal.python.builtins.objects.partial.PPartial; |
| 63 | +import com.oracle.graal.python.lib.PyCallableCheckNode; |
| 64 | +import com.oracle.graal.python.lib.PyObjectGetIter; |
| 65 | +import com.oracle.graal.python.nodes.call.CallNode; |
| 66 | +import com.oracle.graal.python.nodes.control.GetNextNode; |
48 | 67 | import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
|
| 68 | +import com.oracle.graal.python.nodes.function.PythonBuiltinNode; |
| 69 | +import com.oracle.graal.python.nodes.function.builtins.PythonTernaryBuiltinNode; |
| 70 | +import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode; |
| 71 | +import com.oracle.graal.python.nodes.object.GetDictIfExistsNode; |
| 72 | +import com.oracle.graal.python.nodes.object.IsBuiltinClassProfile; |
| 73 | +import com.oracle.graal.python.runtime.exception.PException; |
| 74 | +import com.oracle.graal.python.util.PythonUtils; |
| 75 | +import com.oracle.truffle.api.dsl.Cached; |
| 76 | +import com.oracle.truffle.api.dsl.GenerateNodeFactory; |
49 | 77 | import com.oracle.truffle.api.dsl.NodeFactory;
|
| 78 | +import com.oracle.truffle.api.dsl.Specialization; |
| 79 | +import com.oracle.truffle.api.frame.VirtualFrame; |
| 80 | +import com.oracle.truffle.api.library.CachedLibrary; |
| 81 | +import com.oracle.truffle.api.profiles.ConditionProfile; |
50 | 82 |
|
51 | 83 | @CoreFunctions(defineModule = "_functools")
|
52 | 84 | public class FunctoolsModuleBuiltins extends PythonBuiltins {
|
53 | 85 | @Override
|
54 | 86 | protected List<? extends NodeFactory<? extends PythonBuiltinBaseNode>> getNodeFactories() {
|
55 |
| - return new ArrayList<>(); |
| 87 | + return FunctoolsModuleBuiltinsFactory.getFactories(); |
| 88 | + } |
| 89 | + |
| 90 | + // functools.reduce(function, iterable[, initializer]) |
| 91 | + @Builtin(name = "reduce", minNumOfPositionalArgs = 2, maxNumOfPositionalArgs = 3, doc = "reduce(function, sequence[, initial]) -> value\n" + |
| 92 | + "\n" + |
| 93 | + "Apply a function of two arguments cumulatively to the items of a sequence,\n" + |
| 94 | + "from left to right, so as to reduce the sequence to a single value.\n" + |
| 95 | + "For example, reduce(lambda x, y: x+y, [1, 2, 3, 4, 5]) calculates\n" + |
| 96 | + "((((1+2)+3)+4)+5). If initial is present, it is placed before the items\n" + |
| 97 | + "of the sequence in the calculation, and serves as a default when the\n" + |
| 98 | + "sequence is empty.") |
| 99 | + @GenerateNodeFactory |
| 100 | + public abstract static class ReduceNode extends PythonTernaryBuiltinNode { |
| 101 | + @Specialization(guards = "isNoValue(initial)") |
| 102 | + Object doReduceNoInitial(VirtualFrame frame, Object function, Object sequence, @SuppressWarnings("unused") PNone initial, |
| 103 | + @Cached PyObjectGetIter getIter, |
| 104 | + @Cached GetNextNode nextNode, |
| 105 | + @Cached CallNode callNode, |
| 106 | + @Cached IsBuiltinClassProfile stopIterProfile, |
| 107 | + @Cached IsBuiltinClassProfile typeError) { |
| 108 | + return doReduce(frame, function, sequence, null, getIter, nextNode, callNode, stopIterProfile, typeError); |
| 109 | + } |
| 110 | + |
| 111 | + @Specialization(guards = "!isNoValue(initial)") |
| 112 | + Object doReduce(VirtualFrame frame, Object function, Object sequence, Object initial, |
| 113 | + @Cached PyObjectGetIter getIter, |
| 114 | + @Cached GetNextNode nextNode, |
| 115 | + @Cached CallNode callNode, |
| 116 | + @Cached IsBuiltinClassProfile stopIterProfile, |
| 117 | + @Cached IsBuiltinClassProfile typeError) { |
| 118 | + Object seqIterator, result = initial; |
| 119 | + try { |
| 120 | + seqIterator = getIter.execute(frame, sequence); |
| 121 | + } catch (PException pe) { |
| 122 | + pe.expectTypeError(typeError); |
| 123 | + throw raise(PythonBuiltinClassType.TypeError, S_ARG_N_MUST_SUPPORT_ITERATION, "reduce()", 2); |
| 124 | + } |
| 125 | + |
| 126 | + Object[] args = new Object[2]; |
| 127 | + |
| 128 | + int count = 0; |
| 129 | + while (true) { |
| 130 | + Object op2; |
| 131 | + try { |
| 132 | + op2 = nextNode.execute(frame, seqIterator); |
| 133 | + if (result == null) { |
| 134 | + result = op2; |
| 135 | + } else { |
| 136 | + // Update the args tuple in-place |
| 137 | + args[0] = result; |
| 138 | + args[1] = op2; |
| 139 | + result = callNode.execute(frame, function, args); |
| 140 | + } |
| 141 | + count++; |
| 142 | + } catch (PException e) { |
| 143 | + e.expectStopIteration(stopIterProfile); |
| 144 | + break; |
| 145 | + } |
| 146 | + } |
| 147 | + reportLoopCount(this, count >= 0 ? count : Integer.MAX_VALUE); |
| 148 | + |
| 149 | + if (result == null) { |
| 150 | + throw raise(PythonBuiltinClassType.TypeError, REDUCE_EMPTY_SEQ); |
| 151 | + } |
| 152 | + |
| 153 | + return result; |
| 154 | + } |
| 155 | + } |
| 156 | + |
| 157 | + // functools.cmp_to_key(func) |
| 158 | + @Builtin(name = "cmp_to_key", minNumOfPositionalArgs = 1, parameterNames = {"mycmp"}, doc = "Convert a cmp= function into a key= function.") |
| 159 | + @GenerateNodeFactory |
| 160 | + public abstract static class CmpToKeyNode extends PythonUnaryBuiltinNode { |
| 161 | + @Specialization |
| 162 | + Object doConvert(Object myCmp) { |
| 163 | + return factory().createKeyWrapper(myCmp); |
| 164 | + } |
| 165 | + } |
| 166 | + |
| 167 | + // functools.partial(func, /, *args, **keywords) |
| 168 | + @Builtin(name = PARTIAL, minNumOfPositionalArgs = 1, varArgsMarker = true, takesVarArgs = true, takesVarKeywordArgs = true, constructsClass = PythonBuiltinClassType.PPartial, doc = "partial(func, *args, **keywords) - new function with partial application\n" + |
| 169 | + "of the given arguments and keywords.\n") |
| 170 | + @GenerateNodeFactory |
| 171 | + public abstract static class PartialNode extends PythonBuiltinNode { |
| 172 | + protected boolean isPartialWithoutDict(GetDictIfExistsNode getDict, Object[] args, HashingStorageLibrary lib, boolean withKwDict) { |
| 173 | + return isPartialWithoutDict(getDict, args) && withKwDict == ((PPartial) args[0]).hasKw(lib); |
| 174 | + } |
| 175 | + |
| 176 | + protected boolean isPartialWithoutDict(GetDictIfExistsNode getDict, Object[] args) { |
| 177 | + return getDict.execute(args[0]) == null && args[0] instanceof PPartial; |
| 178 | + } |
| 179 | + |
| 180 | + protected boolean withKeywords(PKeyword[] keywords) { |
| 181 | + return keywords.length > 0; |
| 182 | + } |
| 183 | + |
| 184 | + protected boolean atLeastOneArg(Object[] args) { |
| 185 | + return args.length >= 1; |
| 186 | + } |
| 187 | + |
| 188 | + @Specialization(guards = {"atLeastOneArg(args)", "isPartialWithoutDict(getDict, args, lib, false)"}) |
| 189 | + Object createFromPartialWoDictWoKw(Object cls, Object[] args, PKeyword[] keywords, |
| 190 | + @SuppressWarnings("unused") @Cached GetDictIfExistsNode getDict, |
| 191 | + @Cached ConditionProfile hasArgsProfile, |
| 192 | + @Cached ConditionProfile hasKeywordsProfile, |
| 193 | + @SuppressWarnings("unused") @CachedLibrary(limit = "1") HashingStorageLibrary lib) { |
| 194 | + assert args[0] instanceof PPartial; |
| 195 | + final PPartial function = (PPartial) args[0]; |
| 196 | + Object[] funcArgs = getNewPartialArgs(function, args, hasArgsProfile, 1); |
| 197 | + |
| 198 | + PDict funcKwDict; |
| 199 | + if (hasKeywordsProfile.profile(keywords.length > 0)) { |
| 200 | + funcKwDict = factory().createDict(keywords); |
| 201 | + } else { |
| 202 | + funcKwDict = factory().createDict(); |
| 203 | + } |
| 204 | + |
| 205 | + return factory().createPartial(cls, function.getFn(), funcArgs, funcKwDict); |
| 206 | + } |
| 207 | + |
| 208 | + @Specialization(guards = {"atLeastOneArg(args)", "isPartialWithoutDict(getDict, args, lib, true)", "!withKeywords(keywords)"}) |
| 209 | + Object createFromPartialWoDictWKw(Object cls, Object[] args, @SuppressWarnings("unused") PKeyword[] keywords, |
| 210 | + @SuppressWarnings("unused") @Cached GetDictIfExistsNode getDict, |
| 211 | + @Cached ConditionProfile hasArgsProfile, |
| 212 | + @CachedLibrary(limit = "1") HashingStorageLibrary lib) { |
| 213 | + assert args[0] instanceof PPartial; |
| 214 | + final PPartial function = (PPartial) args[0]; |
| 215 | + Object[] funcArgs = getNewPartialArgs(function, args, hasArgsProfile, 1); |
| 216 | + return factory().createPartial(cls, function.getFn(), funcArgs, function.getKwCopy(factory(), lib)); |
| 217 | + } |
| 218 | + |
| 219 | + @Specialization(guards = {"atLeastOneArg(args)", "isPartialWithoutDict(getDict, args, lib, true)", "withKeywords(keywords)"}) |
| 220 | + Object createFromPartialWoDictWKwKw(VirtualFrame frame, Object cls, Object[] args, PKeyword[] keywords, |
| 221 | + @SuppressWarnings("unused") @Cached GetDictIfExistsNode getDict, |
| 222 | + @Cached ConditionProfile hasArgsProfile, |
| 223 | + @Cached HashingStorage.InitNode initNode, |
| 224 | + @CachedLibrary(limit = "1") HashingStorageLibrary lib) { |
| 225 | + assert args[0] instanceof PPartial; |
| 226 | + final PPartial function = (PPartial) args[0]; |
| 227 | + Object[] funcArgs = getNewPartialArgs(function, args, hasArgsProfile, 1); |
| 228 | + |
| 229 | + HashingStorage storage = function.getKw().getDictStorage(); |
| 230 | + storage = lib.addAllToOther(initNode.execute(frame, PNone.NO_VALUE, keywords), storage); |
| 231 | + |
| 232 | + return factory().createPartial(cls, function.getFn(), funcArgs, factory().createDict(storage)); |
| 233 | + } |
| 234 | + |
| 235 | + @Specialization(guards = {"atLeastOneArg(args)", "!isPartialWithoutDict(getDict, args)"}) |
| 236 | + Object createGeneric(Object cls, Object[] args, PKeyword[] keywords, |
| 237 | + @SuppressWarnings("unused") @Cached GetDictIfExistsNode getDict, |
| 238 | + @Cached ConditionProfile hasKeywordsProfile, |
| 239 | + @Cached PyCallableCheckNode callableCheckNode) { |
| 240 | + Object function = args[0]; |
| 241 | + if (!callableCheckNode.execute(function)) { |
| 242 | + throw raise(PythonBuiltinClassType.TypeError, S_ARG_MUST_BE_CALLABLE, "the first"); |
| 243 | + } |
| 244 | + |
| 245 | + final Object[] funcArgs = PythonUtils.arrayCopyOfRange(args, 1, args.length); |
| 246 | + PDict funcKwDict; |
| 247 | + if (hasKeywordsProfile.profile(keywords.length > 0)) { |
| 248 | + funcKwDict = factory().createDict(keywords); |
| 249 | + } else { |
| 250 | + funcKwDict = factory().createDict(); |
| 251 | + } |
| 252 | + return factory().createPartial(cls, function, funcArgs, funcKwDict); |
| 253 | + } |
| 254 | + |
| 255 | + @Specialization(guards = "!atLeastOneArg(args)") |
| 256 | + @SuppressWarnings("unused") |
| 257 | + Object noCallable(Object cls, Object[] args, PKeyword[] keywords) { |
| 258 | + throw raise(PythonBuiltinClassType.TypeError, TYPE_S_TAKES_AT_LEAST_ONE_ARGUMENT, "partial"); |
| 259 | + } |
56 | 260 | }
|
57 | 261 | }
|
0 commit comments