@@ -58,6 +58,8 @@ class ValueRef:
5858 src_cpp_type : str
5959 is_in : bool = False
6060 is_out : bool = False
61+ fixed_storage_type : Optional [str ] = None
62+ fixed_memory_layout : Optional [str ] = None
6163 requires_prepack : bool = False
6264 supports_prepack : bool = False
6365 # When is_dynamic_size is true, the underlying object size is not known
@@ -137,20 +139,43 @@ def __init__(
137139 if arg .name in self .suite_def .prepacked_args :
138140 supports_prepack = True
139141
142+ fixed_storage_type = None
143+ if arg .name in self .suite_def .arg_storage_types :
144+ fixed_storage_type = self .suite_def .arg_storage_types [arg .name ]
145+
146+ fixed_memory_layout = None
147+ if arg .name in self .suite_def .arg_memory_layouts :
148+ fixed_memory_layout = self .suite_def .arg_memory_layouts [arg .name ]
149+
140150 self .refs [arg .name ] = ValueRef (
141151 name = f"{ arg .name } _ref" ,
142152 src_cpp_name = arg .name ,
143153 src_cpp_type = cpp_type ,
144154 is_in = (cpp_type in InableCppType ),
155+ fixed_storage_type = fixed_storage_type ,
156+ fixed_memory_layout = fixed_memory_layout ,
145157 requires_prepack = requires_prepack ,
146158 supports_prepack = supports_prepack ,
147159 )
148160
149161 ret_type = cpp .returns_type (self .f .func .returns , symint = False ).cpp_type ()
150162 self .out = ATenArg (name = "out" , cpp_type = ret_type , default = None )
163+
164+ fixed_storage_type = None
165+ if "out" in self .suite_def .arg_storage_types :
166+ fixed_storage_type = self .suite_def .arg_storage_types ["out" ]
167+ fixed_memory_layout = None
168+ if "out" in self .suite_def .arg_memory_layouts :
169+ fixed_memory_layout = self .suite_def .arg_memory_layouts ["out" ]
170+
151171 if ret_type == AT_TENSOR :
152172 self .refs ["out" ] = ValueRef (
153- name = "out_ref" , src_cpp_name = "out" , src_cpp_type = ret_type , is_out = True
173+ name = "out_ref" ,
174+ src_cpp_name = "out" ,
175+ src_cpp_type = ret_type ,
176+ is_out = True ,
177+ fixed_storage_type = fixed_storage_type ,
178+ fixed_memory_layout = fixed_memory_layout ,
154179 )
155180 elif ret_type == TWO_TENSOR_TUPLE :
156181 self .refs ["out" ] = [
@@ -159,12 +184,24 @@ def __init__(
159184 src_cpp_name = "std::get<0>(out)" ,
160185 src_cpp_type = "at::Tensor" ,
161186 is_out = True ,
187+ fixed_storage_type = (
188+ fixed_storage_type [0 ] if fixed_storage_type else None
189+ ),
190+ fixed_memory_layout = (
191+ fixed_memory_layout [0 ] if fixed_memory_layout else None
192+ ),
162193 ),
163194 ValueRef (
164195 name = "out_ref_second" ,
165196 src_cpp_name = "std::get<1>(out)" ,
166197 src_cpp_type = "at::Tensor" ,
167198 is_out = True ,
199+ fixed_storage_type = (
200+ fixed_storage_type [1 ] if fixed_storage_type else None
201+ ),
202+ fixed_memory_layout = (
203+ fixed_memory_layout [1 ] if fixed_memory_layout else None
204+ ),
168205 ),
169206 ValueRef (
170207 name = "out_ref" ,
@@ -180,18 +217,36 @@ def __init__(
180217 src_cpp_name = "std::get<0>(out)" ,
181218 src_cpp_type = "at::Tensor" ,
182219 is_out = True ,
220+ fixed_storage_type = (
221+ fixed_storage_type [0 ] if fixed_storage_type else None
222+ ),
223+ fixed_memory_layout = (
224+ fixed_memory_layout [0 ] if fixed_memory_layout else None
225+ ),
183226 ),
184227 ValueRef (
185228 name = "out_ref_second" ,
186229 src_cpp_name = "std::get<1>(out)" ,
187230 src_cpp_type = "at::Tensor" ,
188231 is_out = True ,
232+ fixed_storage_type = (
233+ fixed_storage_type [1 ] if fixed_storage_type else None
234+ ),
235+ fixed_memory_layout = (
236+ fixed_memory_layout [1 ] if fixed_memory_layout else None
237+ ),
189238 ),
190239 ValueRef (
191240 name = "out_ref_third" ,
192241 src_cpp_name = "std::get<2>(out)" ,
193242 src_cpp_type = "at::Tensor" ,
194243 is_out = True ,
244+ fixed_storage_type = (
245+ fixed_storage_type [2 ] if fixed_storage_type else None
246+ ),
247+ fixed_memory_layout = (
248+ fixed_memory_layout [2 ] if fixed_memory_layout else None
249+ ),
195250 ),
196251 ValueRef (
197252 name = "out_ref" ,
@@ -302,7 +357,12 @@ def create_value_for( # noqa: C901
302357 ret_str += f"{ self .graph } { self .dot } "
303358 ret_str += "add_input_tensor(" if ref .is_in else "add_tensor("
304359 ret_str += f"{ ref .src_cpp_name } ->sizes().vec(), "
305- ret_str += f"from_at_scalartype({ ref .src_cpp_name } ->scalar_type())); \n "
360+ ret_str += f"from_at_scalartype({ ref .src_cpp_name } ->scalar_type()"
361+ if ref .fixed_storage_type :
362+ ret_str += f", { ref .fixed_storage_type } "
363+ if ref .fixed_memory_layout :
364+ ret_str += f", { ref .fixed_memory_layout } "
365+ ret_str += "));\n "
306366 elif prepack :
307367 ret_str += f"{ self .graph } { self .dot } "
308368 ret_str += f"add_tensorref({ ref .src_cpp_name } ->sizes().vec(), "
@@ -385,7 +445,12 @@ def create_value_for( # noqa: C901
385445 elif ref .src_cpp_type == AT_TENSOR and not prepack :
386446 ret_str += "add_input_tensor(" if ref .is_in else "add_tensor("
387447 ret_str += f"{ ref .src_cpp_name } .sizes().vec(), "
388- ret_str += f"from_at_scalartype({ ref .src_cpp_name } .scalar_type())); \n "
448+ ret_str += f"from_at_scalartype({ ref .src_cpp_name } .scalar_type())"
449+ if ref .fixed_storage_type :
450+ ret_str += f", { ref .fixed_storage_type } "
451+ if ref .fixed_memory_layout :
452+ ret_str += f", { ref .fixed_memory_layout } "
453+ ret_str += ");\n "
389454 elif ref .src_cpp_type == AT_TENSOR and prepack :
390455 ret_str += f"add_tensorref({ ref .src_cpp_name } .sizes().vec(), "
391456 ret_str += f"from_at_scalartype({ ref .src_cpp_name } .scalar_type()), "
0 commit comments