@@ -290,12 +290,6 @@ def alias_of(self) -> Optional[str]:
290290 return None
291291
292292
293- @dataclasses .dataclass
294- class ScalarArg :
295- name : str
296- dtype : torch .dtype
297-
298-
299293@dataclasses .dataclass
300294class ConstexprArg :
301295 name : str
@@ -317,14 +311,7 @@ class DeviceCodegen:
317311 fx_wrapper_codegen : Optional [WrapperConstructor ] = None
318312
319313
320- KernelArgType = Union [
321- WorkspaceArg ,
322- TensorArg ,
323- SizeArg ,
324- ScalarArg ,
325- TMADescriptorArg ,
326- ConstexprArg ,
327- ]
314+ KernelArgType = Union [WorkspaceArg , TensorArg , SizeArg , TMADescriptorArg , ConstexprArg ]
328315
329316device_codegens : dict [str , DeviceCodegen ] = {}
330317
@@ -1480,7 +1467,6 @@ def __init__(self) -> None:
14801467 self .output_buffers : dict [str , Union [str , RemovedArg ]] = {}
14811468 self .inplace_buffers : dict [str , Union [InplacedBuffer , RemovedArg ]] = {}
14821469 self .sizevars : dict [sympy .Expr , str ] = {}
1483- self .scalar_vars : dict [sympy .Symbol , tuple [str , torch .dtype ]] = {}
14841470 self .workspace_args : list [WorkspaceArg ] = []
14851471
14861472 def __repr__ (self ) -> str :
@@ -1493,7 +1479,6 @@ def __repr__(self) -> str:
14931479 self .output_buffers ,
14941480 self .inplace_buffers ,
14951481 self .sizevars ,
1496- self .scalar_vars ,
14971482 ],
14981483 )
14991484 )
@@ -1641,27 +1626,9 @@ def size(self, name: sympy.Symbol) -> str:
16411626 return "seed"
16421627 return self ._lookup ("ks" , self .sizevars , name )
16431628
1644- def scalar (self , name : sympy .Symbol , dtype : torch .dtype ) -> str :
1645- assert isinstance (name , sympy .Symbol ), (type (name ), name )
1646- if name in self .scalar_vars :
1647- inner , existing_dtype = self .scalar_vars [name ]
1648- if existing_dtype != dtype :
1649- try :
1650- promoted = torch .promote_types (existing_dtype , dtype )
1651- except TypeError :
1652- promoted = dtype
1653- self .scalar_vars [name ] = (inner , promoted )
1654- return self .scalar_vars [name ][0 ]
1655- inner = f"kscalar{ len (self .scalar_vars )} "
1656- self .scalar_vars [name ] = (inner , dtype )
1657- return inner
1658-
16591629 def call_names (self ) -> Iterator [str ]:
16601630 return chain (
1661- self .input_buffers .keys (),
1662- self .output_buffers .keys (),
1663- self .sizevars .keys (),
1664- self .scalar_vars .keys (),
1631+ self .input_buffers .keys (), self .output_buffers .keys (), self .sizevars .keys ()
16651632 )
16661633
16671634 def arg_name (self , name : str ) -> Optional [str ]:
@@ -1682,9 +1649,6 @@ def wrap_ptr_arg(self, buf: str, dtype: torch.dtype) -> str:
16821649 def wrap_size_arg (self , size : SymbolLike ) -> str :
16831650 return str (size )
16841651
1685- def wrap_scalar_arg (self , scalar : sympy .Symbol ) -> str :
1686- return str (scalar )
1687-
16881652 def cpp_argdefs (
16891653 self , dtype_to_cpp_type : Optional [dict [torch .dtype , str ]] = None
16901654 ) -> tuple [list [str ], list [str ], list [str ]]:
@@ -1730,11 +1694,6 @@ def cpp_argdefs(
17301694 arg_types .append (f"const { INDEX_TYPE } " )
17311695 if V .graph .wrapper_code :
17321696 V .graph .wrapper_code .ensure_size_computed (outer )
1733- for outer , (inner , dtype ) in self .scalar_vars .items ():
1734- cpp_dtype = dtype_to_cpp_type [dtype ]
1735- arg_defs .append (f"const { cpp_dtype } { inner } " )
1736- call_args .append (self .wrap_scalar_arg (outer ))
1737- arg_types .append (f"const { cpp_dtype } " )
17381697 assert not self .workspace_args , "Workspace not supported on CPU "
17391698 return arg_defs , call_args , arg_types
17401699
@@ -1780,11 +1739,6 @@ def python_argdefs(
17801739 precompile_args .append (SizeArg (inner , outer ))
17811740 if V .graph .wrapper_code :
17821741 V .graph .wrapper_code .ensure_size_computed (outer )
1783- for outer , (inner , dtype ) in self .scalar_vars .items ():
1784- arg_defs .append (ArgName (inner ))
1785- call_args .append (self .wrap_scalar_arg (outer ))
1786- arg_types .append (dtype )
1787- precompile_args .append (ScalarArg (inner , dtype ))
17881742 for arg in self .workspace_args :
17891743 arg_defs .append (ArgName (arg .inner_name ))
17901744 call_args .append (arg .outer_name )
@@ -2339,10 +2293,6 @@ def rename_indexing(
23392293 ),
23402294 )
23412295 }
2342- for x in sorted_symbols :
2343- if symbol_is_type (x , (SymT .FLOAT , SymT .UNBACKED_FLOAT )):
2344- dtype = V .graph .get_dynamic_scalar_dtype (x )
2345- replacements [x ] = self .args .scalar (x , dtype )
23462296 return sympy_subs (index , replacements )
23472297
23482298 def create_cse_var (self , * args : Any , ** kwargs : Any ) -> CSEVariable :
0 commit comments