Commit 7b806a8
Revert "[inductor][dynamo] Include operator name in size/stride/alignment assertion (pytorch#152353)"
This reverts commit 9357635. Reverted pytorch#152353 on behalf of https://github.com/huydhn: "Sorry for reverting your change, but it seems to fail an inductor test in trunk" ([comment](pytorch#152353 (comment)))
1 parent d291fa8 commit 7b806a8

File tree (5 files changed: +19 −160 lines)

- test/inductor/test_mkldnn_pattern_matcher.py
- test/inductor/test_torchinductor.py
- torch/_C/_dynamo/guards.pyi
- torch/_inductor/ir.py
- torch/csrc/dynamo/guards.cpp
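
At a glance, the reverted PR had threaded an optional operator name through dynamo's two guard helpers so that assertion failures could name the op that produced the bad buffer; this revert restores their original arity. A hedged before/after sketch of the Python-visible signatures, read off the guards.pyi hunk below:

```python
import torch
from torch._C._dynamo.guards import assert_size_stride

# Pre-revert (pytorch#152353), per torch/_C/_dynamo/guards.pyi:
#   assert_size_stride(item, size, stride, op_name=None)
#   assert_alignment(item, alignment, op_name=None)
# Post-revert, only the positional metadata remains:
#   assert_size_stride(item, size, stride)

t = torch.empty((16, 32))
assert_size_stride(t, (16, 32), (32, 1))  # returns silently when metadata matches
```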

test/inductor/test_mkldnn_pattern_matcher.py

Lines changed: 0 additions & 8 deletions
```diff
@@ -231,14 +231,6 @@ def _test_code_common(
             torch.compile(mod, fullgraph=True, dynamic=check_dynamic),
             *clone_inputs,
         )
-        assert_keywords = ["assert_size_stride", "assert_alignment"]
-        filtered_lines = [
-            line
-            for line in source_code.splitlines()
-            if not any(assert_key in line for assert_key in assert_keywords)
-        ]
-        source_code = "\n".join(filtered_lines)
-
         for op in include_ops:
             self.assertIn(op, source_code)
         if num_include_ops is not None:
```
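
The deleted lines had stripped inductor's size/stride/alignment assertions out of the generated source before the op-count checks ran; with the revert, assertion lines are visible to those checks again. A minimal standalone sketch of the removed filtering idiom (the sample `source_code` string is ours):

```python
# Hypothetical generated source; only the assertion line should be dropped.
source_code = "buf0 = kernel(x)\nassert_size_stride(buf0, (16, 32), (32, 1))\nreturn buf0"

assert_keywords = ["assert_size_stride", "assert_alignment"]
filtered_lines = [
    line
    for line in source_code.splitlines()
    if not any(assert_key in line for assert_key in assert_keywords)
]
print("\n".join(filtered_lines))  # -> "buf0 = kernel(x)" and "return buf0"
```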

test/inductor/test_torchinductor.py

Lines changed: 7 additions & 83 deletions
```diff
@@ -30,7 +30,6 @@
 import torch._dynamo.config as dynamo_config
 import torch._inductor.aoti_eager
 import torch.nn as nn
-from torch._C._dynamo.guards import assert_alignment, assert_size_stride
 from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.debug_utils import aot_graph_input_parser
 from torch._dynamo.device_interface import get_interface_for_device
@@ -1411,10 +1410,9 @@ def fn(a, b):
         )
         _, code = run_and_get_code(fn, x, y)
         code = " ".join(code)
-        if config.cpp_wrapper:
-            self.assertEqual(code.count("view_dtype"), 3)
-        else:
-            self.assertEqual(code.count("aten.view"), 9)
+        self.assertEqual(
+            code.count("view_dtype" if config.cpp_wrapper else "aten.view"), 3
+        )
 
     def test_add_complex5(self):
         def fn(a, b, alpha):
@@ -11884,80 +11882,6 @@ def fn(x):
             check_lowp=False,
         )
 
-    @requires_gpu()
-    @skip_if_not_triton
-    @config.patch(implicit_fallbacks=True)
-    def test_generated_code_has_size_stride_assert(self):
-        def foo(x):
-            return 3 * x
-
-        def foo_meta(x):
-            return torch.empty_like(x)
-
-        define_custom_op_for_test("foo", foo, foo_meta)
-
-        def fn(x):
-            a = torch.nn.functional.relu(x)
-            b = torch.ops.test.foo(a)
-            return b
-
-        a = torch.randn((16, 32), device=self.device)
-
-        _, code = run_and_get_code(
-            torch.compile(fn),
-            a,
-        )
-        if not is_dynamic_shape_enabled():
-            FileCheck().check(
-                "assert_size_stride(buf2, (16, 32), (32, 1), 'torch.ops.test.foo.default')"
-            ).run(code[0])
-
-    @requires_gpu()
-    @skip_if_not_triton
-    @config.patch(implicit_fallbacks=True)
-    def test_generated_code_has_alignment_assert(self):
-        def foo(x):
-            return 3 * x
-
-        def foo_meta(x):
-            return torch.empty_like(x)
-
-        define_custom_op_for_test("foo", foo, foo_meta)
-
-        def fn(x):
-            a = torch.nn.functional.relu(x)
-            b = torch.ops.test.foo(a)
-            return b
-
-        a = torch.randn((16, 32), device=self.device)
-
-        _, code = run_and_get_code(
-            torch.compile(fn),
-            a,
-        )
-        if not is_dynamic_shape_enabled():
-            FileCheck().check(
-                "assert_alignment(buf2, 16, 'torch.ops.test.foo.default')"
-            ).run(code[0])
-
-    def test_assert_size_stride_op_name_pass(self):
-        tensor = torch.empty((16, 32))
-        assert_size_stride(tensor, (16, 32), (32, 1), "torch.ops.dummy.op_name")
-
-    def test_assert_size_stride_op_name_fail(self):
-        tensor = torch.empty((16, 32))
-        with self.assertRaisesRegex(AssertionError, "torch.ops.dummy.op_name"):
-            assert_size_stride(tensor, (32, 64), (32, 1), "torch.ops.dummy.op_name")
-
-    def test_assert_alignment_op_name_pass(self):
-        tensor = torch.empty((16, 32))
-        assert_alignment(tensor, 16, "torch.ops.dummy.op_name")
-
-    def test_assert_alignment_op_name_fail(self):
-        tensor = torch.empty((16, 32))
-        with self.assertRaisesRegex(AssertionError, "torch.ops.dummy.op_name"):
-            assert_alignment(tensor, 0, "torch.ops.dummy.op_name")
-
     @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
     @torch._inductor.config.patch(implicit_fallbacks=True)
     def test_custom_op_unbacked_symints(self):
@@ -13089,12 +13013,12 @@ def f(x):
         code = run_and_get_triton_code(f, x)
 
         if is_dynamic_shape_enabled():
-            FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1)").check(
-                "assert_size_stride(buf2, (s77, s27), (s27, 1)"
+            FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1))").check(
+                "assert_size_stride(buf2, (s77, s27), (s27, 1))"
             ).run(code)
         else:
-            FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1)").check(
-                "assert_size_stride(buf2, (16, 32), (32, 1)"
+            FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1))").check(
+                "assert_size_stride(buf2, (16, 32), (32, 1))"
             ).run(code)
 
     @requires_cuda
```
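
The removed tests leaned on FileCheck to assert that literal substrings appear, in order, in the generated wrapper code; a minimal standalone sketch of that pattern (the `generated` string here is a stand-in, not real inductor output):

```python
from torch.testing import FileCheck

# Stand-in for code returned by run_and_get_code; real output is much longer.
generated = """
buf2 = empty_strided((16, 32), (32, 1))
assert_size_stride(buf2, (16, 32), (32, 1))
"""

# check() queues a substring; run() raises if any queued pattern is missing.
FileCheck().check("assert_size_stride(buf2, (16, 32), (32, 1))").run(generated)
```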

torch/_C/_dynamo/guards.pyi

Lines changed: 0 additions & 6 deletions
```diff
@@ -176,12 +176,6 @@ def assert_size_stride(
     item: torch.Tensor,
     size: torch.types._size,
     stride: torch.types._size,
-    op_name: str | None = None,
-): ...
-def assert_alignment(
-    item: torch.Tensor,
-    alignment: int,
-    op_name: str | None = None,
 ): ...
 def check_obj_id(obj: object, expected: int) -> bool: ...
 def check_type_id(obj: object, expected: int) -> bool: ...
```
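
With the stub reverted, `assert_size_stride` is back to three arguments. A quick hedged check of the surviving API, exercising the failure path the deleted op-name tests covered (the post-revert message simply no longer names an op):

```python
import torch
from torch._C._dynamo.guards import assert_size_stride

tensor = torch.empty((16, 32))
assert_size_stride(tensor, (16, 32), (32, 1))  # matching metadata: no error

try:
    # Deliberately wrong size: raises AssertionError describing the mismatch.
    assert_size_stride(tensor, (32, 64), (32, 1))
except AssertionError as e:
    print(e)
```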

torch/_inductor/ir.py

Lines changed: 4 additions & 20 deletions
```diff
@@ -5736,42 +5736,26 @@ def codegen_kwargs(self, skip_out=False):  # type: ignore[no-untyped-def]
         ]
         return kwargs
 
-    def get_op_name(self) -> str:
-        if self.fx_node is not None:
-            target = self.fx_node.target
-            op_namespace = getattr(target, "__module__", "unknown_namespace")
-            op_namespace = op_namespace.replace("._ops.", ".ops.")
-            op_namespace = op_namespace.rsplit(".", 1)[0]
-            op_name = f"{op_namespace}.{target}"
-        else:
-            op_name = "unknown_op"
-        return op_name
-
     def codegen_size_asserts(self, wrapper) -> None:  # type: ignore[no-untyped-def]
         if config.size_asserts and not V.graph.cpp_wrapper:
             # comparing strides for 0 size tensor is tricky. Ignore them for now.
             if sympy_product(self.get_size()) == 0:
                 return
             size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size())
             stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride())
-            op_name = self.get_op_name()
+
             wrapper.writeline(
-                f"assert_size_stride({self.get_name()}, {size}, {stride}, {op_name!r})"
+                f"assert_size_stride({self.get_name()}, {size}, {stride})"
             )
 
     def codegen_alignment_asserts(self, wrapper) -> None:  # type: ignore[no-untyped-def]
         if config.alignment_asserts and not V.graph.cpp_wrapper:
             name = self.get_name()
             aligned = name not in V.graph.unaligned_buffers
-            op_name = self.get_op_name()
             if aligned:
-                wrapper.writeline(
-                    f"assert_alignment({name}, {GPU_ALIGN_BYTES}, {op_name!r})"
-                )
+                wrapper.writeline(f"assert_alignment({name}, {GPU_ALIGN_BYTES})")
             else:
-                wrapper.writeline(
-                    f"# buffer {name} (op: {op_name}) is assumed to be not aligned"
-                )
+                wrapper.writeline(f"# buffer {name} is assumed to be not aligned")
 
     def get_group_stride(self):  # type: ignore[no-untyped-def]
         """
```

torch/csrc/dynamo/guards.cpp

Lines changed: 8 additions & 43 deletions
```diff
@@ -844,38 +844,21 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
   PyObject* item = nullptr;
   PyObject* size = nullptr;
   PyObject* stride = nullptr;
-  const char* op_name = nullptr;
-
-  if (!PyArg_ParseTuple(args, "OOO|s", &item, &size, &stride, &op_name)) {
+  if (!PyArg_ParseTuple(args, "OOO", &item, &size, &stride)) {
     return nullptr;
   }
   if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
-    std::stringstream msg;
-    msg << "expected Tensor()";
-    if (op_name) {
-      msg << " for op: " << op_name;
-    }
-    PyErr_SetString(PyExc_TypeError, msg.str().c_str());
+    PyErr_SetString(PyExc_TypeError, "expected Tensor()");
     return nullptr;
   }
   if (!PyTuple_CheckExact(size) || !PyTuple_CheckExact(stride)) {
-    std::stringstream msg;
-    msg << "expected tuple()";
-    if (op_name) {
-      msg << " for op: " << op_name;
-    }
-    PyErr_SetString(PyExc_TypeError, msg.str().c_str());
+    PyErr_SetString(PyExc_TypeError, "expected tuple()");
     return nullptr;
   }
   at::Tensor tensor = THPVariable_Unpack(item);
   int64_t ndim = tensor.ndimension();
   if (PyTuple_GET_SIZE(size) != ndim || PyTuple_GET_SIZE(stride) != ndim) {
-    std::stringstream msg;
-    msg << "wrong number of dimensions" << ndim;
-    if (op_name) {
-      msg << " for op: " << op_name;
-    }
-    PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
+    PyErr_SetString(PyExc_AssertionError, "wrong number of dimensions");
     return nullptr;
   }
 
@@ -904,9 +887,6 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
   }
 
   if (num_errors) {
-    if (op_name) {
-      msg << "\nError in op: " << op_name;
-    }
     msg << "\nThis error most often comes from a incorrect fake (aka meta) kernel for a custom op.";
     msg << "\nUse torch.library.opcheck to test your custom op.";
     msg << "\nSee https://pytorch.org/docs/stable/library.html#torch.library.opcheck";
@@ -924,27 +904,15 @@ static PyObject* assert_alignment(PyObject* dummy, PyObject* args) {
   */
   PyObject* item = nullptr;
   unsigned long alignment = 0;
-  const char* op_name = nullptr;
-
-  if (!PyArg_ParseTuple(args, "Ok|s", &item, &alignment, &op_name)) {
+  if (!PyArg_ParseTuple(args, "Ok", &item, &alignment)) {
     return nullptr;
   }
   if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
-    std::stringstream msg;
-    msg << "expected Tensor()";
-    if (op_name) {
-      msg << " for op: " << op_name;
-    }
-    PyErr_SetString(PyExc_TypeError, msg.str().c_str());
+    PyErr_SetString(PyExc_TypeError, "expected Tensor()");
     return nullptr;
   }
   if (alignment == 0) {
-    std::stringstream msg;
-    msg << "alignment cannot be 0";
-    if (op_name) {
-      msg << " in op: " << op_name;
-    }
-    PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
+    PyErr_SetString(PyExc_AssertionError, "alignment can not be 0");
     return nullptr;
   }
 
@@ -954,10 +922,7 @@ static PyObject* assert_alignment(PyObject* dummy, PyObject* args) {
   size_t itemsize = tensor.itemsize();
   if (storage_offset * itemsize % alignment != 0) {
     std::stringstream msg;
-    if (op_name) {
-      msg << "\nError in op: " << op_name;
-    }
-    msg << "\nExpect the tensor to be " << alignment
+    msg << "Expect the tensor to be " << alignment
         << " bytes aligned. Fail due to storage_offset=" << storage_offset
         << " itemsize=" << itemsize;
     PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
```
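
The C++ alignment check reduces to arithmetic on the storage offset and element size; a hedged Python equivalent of that condition (the helper name is ours):

```python
import torch

def is_aligned(t: torch.Tensor, alignment: int) -> bool:
    # Same condition as the C++ code: the tensor's data must start at a
    # byte offset into its storage that is divisible by `alignment`.
    assert alignment != 0, "alignment can not be 0"
    return (t.storage_offset() * t.element_size()) % alignment == 0

x = torch.empty(16, 32)
print(is_aligned(x, 16))         # True: storage_offset is 0
print(is_aligned(x[:, 1:], 16))  # False: 1 element * 4 bytes = 4, 4 % 16 != 0
```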
