
Commit 3e1b1a3

jansel authored and pytorchmergebot committed
Revert "[inductor] Fix issue with scalar arg handling" (pytorch#163737)
This reverts commit a8cd437. See pytorch#163481 (comment). This PR might also cause issues with cudagraphs.

Pull Request resolved: pytorch#163737
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#163386, pytorch#163398, pytorch#163387, pytorch#163414, pytorch#163415, pytorch#163419, pytorch#163434, pytorch#163393, pytorch#163412, pytorch#163422, pytorch#163481, pytorch#163520, pytorch#163482
1 parent 2390d34 commit 3e1b1a3

File tree: 6 files changed, +3 −104 lines changed


test/inductor/test_cpu_repro.py

Lines changed: 0 additions & 11 deletions

@@ -485,17 +485,6 @@ def forward(self, x):
         example_inputs = (torch.rand(1, 10),)
         self.common(Model(), example_inputs)
 
-    @torch._dynamo.config.patch(capture_scalar_outputs=True)
-    def test_fill_diagonal_item_scalar_cpu(self):
-        def fn():
-            x = torch.ones(3, 3)
-            x.fill_diagonal_(0)
-            return x.sum().item()
-
-        compiled = torch.compile(fn, backend="inductor", fullgraph=True)
-        eager = fn()
-        self.assertEqual(compiled(), eager)
-
     @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
     @patch("torch.cuda.is_available", lambda: False)
     def test_linear_packed(self):
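For context, the removed test exercises a compiled function whose return value comes from .item() on a CPU tensor. A standalone sketch of that pattern (illustrative only, not part of the commit):

import torch

# Minimal sketch of the scalar-output pattern from the removed test
# (illustrative only, not part of this commit).
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def run():
    def fn():
        x = torch.ones(3, 3)
        x.fill_diagonal_(0)
        return x.sum().item()

    compiled = torch.compile(fn, backend="inductor", fullgraph=True)
    assert compiled() == fn()

run()

With the revert, this scenario is no longer covered by test_cpu_repro.py.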

torch/_inductor/codecache.py

Lines changed: 1 addition & 24 deletions

@@ -2864,8 +2864,6 @@ def _worker_compile_cpp(
 # Customized Python binding for cpp kernels
 @clear_on_fresh_cache
 class CppPythonBindingsCodeCache(CppCodeCache):
-    """Compile and cache CPU C++ kernels together with lightweight Python bindings."""
-
     cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
     cache_clear = staticmethod(cache.clear)
     cpp_compile_command_flags = {
@@ -2875,28 +2873,7 @@ class CppPythonBindingsCodeCache(CppCodeCache):
     }
     entry_function = "kernel"
     call_entry_function = "kernel({}); Py_RETURN_NONE;"
-    extra_parse_arg = textwrap.dedent(
-        """
-        template <> inline double parse_arg<double>(PyObject* args, size_t n) {{
-            auto result = PyFloat_AsDouble(PyTuple_GET_ITEM(args, n));
-            if(unlikely(result == -1.0 && PyErr_Occurred()))
-                throw std::runtime_error("expected float arg");
-            return result;
-        }}
-        template <> inline float parse_arg<float>(PyObject* args, size_t n) {{
-            auto result = PyFloat_AsDouble(PyTuple_GET_ITEM(args, n));
-            if(unlikely(result == -1.0 && PyErr_Occurred()))
-                throw std::runtime_error("expected float arg");
-            return static_cast<float>(result);
-        }}
-        template <> inline bool parse_arg<bool>(PyObject* args, size_t n) {{
-            int result = PyObject_IsTrue(PyTuple_GET_ITEM(args, n));
-            if(unlikely(result == -1 && PyErr_Occurred()))
-                throw std::runtime_error("expected bool arg");
-            return result;
-        }}
-        """
-    )
+    extra_parse_arg = ""
     suffix_template = textwrap.dedent(
         """
         // Python bindings to call {entry_func}():

torch/_inductor/codegen/common.py

Lines changed: 2 additions & 52 deletions

@@ -290,12 +290,6 @@ def alias_of(self) -> Optional[str]:
         return None
 
 
-@dataclasses.dataclass
-class ScalarArg:
-    name: str
-    dtype: torch.dtype
-
-
 @dataclasses.dataclass
 class ConstexprArg:
     name: str
@@ -317,14 +311,7 @@ class DeviceCodegen:
     fx_wrapper_codegen: Optional[WrapperConstructor] = None
 
 
-KernelArgType = Union[
-    WorkspaceArg,
-    TensorArg,
-    SizeArg,
-    ScalarArg,
-    TMADescriptorArg,
-    ConstexprArg,
-]
+KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg, TMADescriptorArg, ConstexprArg]
 
 device_codegens: dict[str, DeviceCodegen] = {}
 
@@ -1480,7 +1467,6 @@ def __init__(self) -> None:
         self.output_buffers: dict[str, Union[str, RemovedArg]] = {}
         self.inplace_buffers: dict[str, Union[InplacedBuffer, RemovedArg]] = {}
         self.sizevars: dict[sympy.Expr, str] = {}
-        self.scalar_vars: dict[sympy.Symbol, tuple[str, torch.dtype]] = {}
         self.workspace_args: list[WorkspaceArg] = []
 
     def __repr__(self) -> str:
@@ -1493,7 +1479,6 @@ def __repr__(self) -> str:
                     self.output_buffers,
                     self.inplace_buffers,
                     self.sizevars,
-                    self.scalar_vars,
                 ],
             )
         )
@@ -1641,27 +1626,9 @@ def size(self, name: sympy.Symbol) -> str:
             return "seed"
         return self._lookup("ks", self.sizevars, name)
 
-    def scalar(self, name: sympy.Symbol, dtype: torch.dtype) -> str:
-        assert isinstance(name, sympy.Symbol), (type(name), name)
-        if name in self.scalar_vars:
-            inner, existing_dtype = self.scalar_vars[name]
-            if existing_dtype != dtype:
-                try:
-                    promoted = torch.promote_types(existing_dtype, dtype)
-                except TypeError:
-                    promoted = dtype
-                self.scalar_vars[name] = (inner, promoted)
-            return self.scalar_vars[name][0]
-        inner = f"kscalar{len(self.scalar_vars)}"
-        self.scalar_vars[name] = (inner, dtype)
-        return inner
-
     def call_names(self) -> Iterator[str]:
         return chain(
-            self.input_buffers.keys(),
-            self.output_buffers.keys(),
-            self.sizevars.keys(),
-            self.scalar_vars.keys(),
+            self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
         )
 
     def arg_name(self, name: str) -> Optional[str]:
@@ -1682,9 +1649,6 @@ def wrap_ptr_arg(self, buf: str, dtype: torch.dtype) -> str:
     def wrap_size_arg(self, size: SymbolLike) -> str:
         return str(size)
 
-    def wrap_scalar_arg(self, scalar: sympy.Symbol) -> str:
-        return str(scalar)
-
     def cpp_argdefs(
         self, dtype_to_cpp_type: Optional[dict[torch.dtype, str]] = None
     ) -> tuple[list[str], list[str], list[str]]:
@@ -1730,11 +1694,6 @@ def cpp_argdefs(
             arg_types.append(f"const {INDEX_TYPE}")
             if V.graph.wrapper_code:
                 V.graph.wrapper_code.ensure_size_computed(outer)
-        for outer, (inner, dtype) in self.scalar_vars.items():
-            cpp_dtype = dtype_to_cpp_type[dtype]
-            arg_defs.append(f"const {cpp_dtype} {inner}")
-            call_args.append(self.wrap_scalar_arg(outer))
-            arg_types.append(f"const {cpp_dtype}")
         assert not self.workspace_args, "Workspace not supported on CPU "
         return arg_defs, call_args, arg_types
 
@@ -1780,11 +1739,6 @@ def python_argdefs(
             precompile_args.append(SizeArg(inner, outer))
             if V.graph.wrapper_code:
                 V.graph.wrapper_code.ensure_size_computed(outer)
-        for outer, (inner, dtype) in self.scalar_vars.items():
-            arg_defs.append(ArgName(inner))
-            call_args.append(self.wrap_scalar_arg(outer))
-            arg_types.append(dtype)
-            precompile_args.append(ScalarArg(inner, dtype))
         for arg in self.workspace_args:
             arg_defs.append(ArgName(arg.inner_name))
             call_args.append(arg.outer_name)
@@ -2339,10 +2293,6 @@ def rename_indexing(
                 ),
             )
         }
-        for x in sorted_symbols:
-            if symbol_is_type(x, (SymT.FLOAT, SymT.UNBACKED_FLOAT)):
-                dtype = V.graph.get_dynamic_scalar_dtype(x)
-                replacements[x] = self.args.scalar(x, dtype)
        return sympy_subs(index, replacements)
 
    def create_cse_var(self, *args: Any, **kwargs: Any) -> CSEVariable:
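For readers skimming the revert, the removed KernelArgs.scalar helper follows a small cache-with-promotion pattern: one inner name per symbol, with conflicting dtypes reconciled through torch.promote_types. A condensed standalone sketch (names here are illustrative, not the Inductor API):

import torch

# Condensed restatement of the removed cache-with-promotion pattern
# (illustrative only; not the Inductor API).
scalar_vars: dict[str, tuple[str, torch.dtype]] = {}

def scalar(name: str, dtype: torch.dtype) -> str:
    if name in scalar_vars:
        inner, existing = scalar_vars[name]
        if existing != dtype:
            # Same symbol seen with a different dtype: promote instead of overwriting.
            scalar_vars[name] = (inner, torch.promote_types(existing, dtype))
        return inner
    inner = f"kscalar{len(scalar_vars)}"
    scalar_vars[name] = (inner, dtype)
    return inner

print(scalar("zuf0", torch.float32))  # kscalar0
print(scalar("zuf0", torch.float64))  # kscalar0 again; cached dtype promoted to float64
print(scalar_vars["zuf0"][1])         # torch.float64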

torch/_inductor/codegen/triton_utils.py

Lines changed: 0 additions & 4 deletions

@@ -13,7 +13,6 @@
     ArgName,
     ConstexprArg,
     KernelArgType,
-    ScalarArg,
     SizeArg,
     TensorArg,
     TMADescriptorArg,
@@ -92,9 +91,6 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str:
            raise NotImplementedError(f"unhandled size_dtype {size_dtype}")
    if isinstance(arg, WorkspaceArg):
        return _type_of(arg.dtype)
-    if isinstance(arg, ScalarArg):
-        typ = _type_of(arg.dtype)
-        return typ.removeprefix("*")
    if isinstance(arg, TMADescriptorArg):
        if arg.api_type == "experimental":
            return "nvTmaDesc"

torch/_inductor/graph.py

Lines changed: 0 additions & 12 deletions

@@ -385,7 +385,6 @@ def __init__(
             const_module.device_idxs if const_module else OrderedSet()
         )
         self.device_type = "cpu"
-        self.dynamic_scalar_dtypes: dict[sympy.Symbol, torch.dtype] = {}
 
         # Inplace padding may require Inductor to allocate slightly larger
         # tensor for padding.
@@ -950,17 +949,6 @@ def get_dtype(self, buffer_name: str) -> torch.dtype:
             return self.get_dtype(m.group(1))
         raise KeyError(f"could not find {buffer_name}")
 
-    def register_dynamic_scalar_dtype(
-        self, sym: sympy.Symbol, dtype: torch.dtype
-    ) -> None:
-        existing = self.dynamic_scalar_dtypes.get(sym)
-        if existing is not None and existing != dtype:
-            dtype = torch.promote_types(existing, dtype)
-        self.dynamic_scalar_dtypes[sym] = dtype
-
-    def get_dynamic_scalar_dtype(self, sym: sympy.Symbol) -> torch.dtype:
-        return self.dynamic_scalar_dtypes.get(sym, torch.float64)
-
     def get_numel(self, buffer_name: str) -> Union[int, Expr]:
         if buffer_name in self.constants:
             return self.constants[buffer_name].numel()
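The removed helpers reconciled repeated registrations with torch.promote_types and fell back to torch.float64 for unregistered symbols. A few promotion examples for reference (standard torch API, independent of this commit):

import torch

# Standard PyTorch dtype promotion, as used by the removed register_dynamic_scalar_dtype.
print(torch.promote_types(torch.float32, torch.float64))  # torch.float64
print(torch.promote_types(torch.int64, torch.float32))    # torch.float32
print(torch.promote_types(torch.bool, torch.int32))       # torch.int32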

torch/_inductor/lowering.py

Lines changed: 0 additions & 1 deletion

@@ -3236,7 +3236,6 @@ def _local_scalar_dense(data):
     buffer = ir.DynamicScalar(binding_sym, keypath, data)
     buffer.name = V.graph.register_buffer(buffer)
     V.graph.register_operation(buffer)
-    V.graph.register_dynamic_scalar_dtype(binding_sym, data.get_dtype())
     # NB: the replaced expr is OK to use directly downstream, we want
     # simplifications in this case!
     val = V.graph.current_node.meta["val"]
