Skip to content

Commit 7519642

Browse files
committed
Avoid direct slice struct usage in torch
1 parent 91452ce commit 7519642

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

graalpython/lib-graalpython/patches/torch/torch-1.13.1.patch

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,46 @@ index 8e1ca3b1..b150ac3f 100644
269269
}
270270

271271
Py_INCREF(obj);
272+
diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp
273+
index 8c9ed1d7..5183a325 100644
274+
--- a/torch/csrc/autograd/python_variable_indexing.cpp
275+
+++ b/torch/csrc/autograd/python_variable_indexing.cpp
276+
@@ -141,26 +141,28 @@ static inline void checkUnpackSlice(
277+
}
278+
279+
static inline void recordSliceTrace(PyObject* obj) {
280+
- PySliceObject* sliceobj = (PySliceObject*)obj;
281+
- if (THPVariable_Check(sliceobj->start)) {
282+
+ PyObject* slicestart = PySlice_Start(obj);
283+
+ if (THPVariable_Check(slicestart)) {
284+
torch::jit::tracer::ArgumentStash::stashValue(
285+
std::string("start"),
286+
1,
287+
- THPVariable_Unpack(sliceobj->start),
288+
+ THPVariable_Unpack(slicestart),
289+
torch::jit::IntType::get());
290+
}
291+
- if (THPVariable_Check(sliceobj->stop)) {
292+
+ PyObject* slicestop = PySlice_Stop(obj);
293+
+ if (THPVariable_Check(slicestop)) {
294+
torch::jit::tracer::ArgumentStash::stashValue(
295+
std::string("end"),
296+
1,
297+
- THPVariable_Unpack(sliceobj->stop),
298+
+ THPVariable_Unpack(slicestop),
299+
torch::jit::IntType::get());
300+
}
301+
- if (THPVariable_Check(sliceobj->step)) {
302+
+ PyObject* slicestep = PySlice_Step(obj);
303+
+ if (THPVariable_Check(slicestep)) {
304+
torch::jit::tracer::ArgumentStash::stashValue(
305+
std::string("step"),
306+
1,
307+
- THPVariable_Unpack(sliceobj->step),
308+
+ THPVariable_Unpack(slicestep),
309+
torch::jit::IntType::get());
310+
}
311+
}
272312
diff --git a/torch/csrc/jit/python/python_tracer.cpp b/torch/csrc/jit/python/python_tracer.cpp
273313
index 78676e2e..ecc76ea9 100644
274314
--- a/torch/csrc/jit/python/python_tracer.cpp

0 commit comments

Comments
 (0)