@@ -58,6 +58,8 @@ class ValueRef:
58
58
src_cpp_type : str
59
59
is_in : bool = False
60
60
is_out : bool = False
61
+ fixed_storage_type : Optional [str ] = None
62
+ fixed_memory_layout : Optional [str ] = None
61
63
requires_prepack : bool = False
62
64
supports_prepack : bool = False
63
65
# When is_dynamic_size is true, the underlying object size is not known
@@ -137,20 +139,43 @@ def __init__(
137
139
if arg .name in self .suite_def .prepacked_args :
138
140
supports_prepack = True
139
141
142
+ fixed_storage_type = None
143
+ if arg .name in self .suite_def .arg_storage_types :
144
+ fixed_storage_type = self .suite_def .arg_storage_types [arg .name ]
145
+
146
+ fixed_memory_layout = None
147
+ if arg .name in self .suite_def .arg_memory_layouts :
148
+ fixed_memory_layout = self .suite_def .arg_memory_layouts [arg .name ]
149
+
140
150
self .refs [arg .name ] = ValueRef (
141
151
name = f"{ arg .name } _ref" ,
142
152
src_cpp_name = arg .name ,
143
153
src_cpp_type = cpp_type ,
144
154
is_in = (cpp_type in InableCppType ),
155
+ fixed_storage_type = fixed_storage_type ,
156
+ fixed_memory_layout = fixed_memory_layout ,
145
157
requires_prepack = requires_prepack ,
146
158
supports_prepack = supports_prepack ,
147
159
)
148
160
149
161
ret_type = cpp .returns_type (self .f .func .returns , symint = False ).cpp_type ()
150
162
self .out = ATenArg (name = "out" , cpp_type = ret_type , default = None )
163
+
164
+ fixed_storage_type = None
165
+ if "out" in self .suite_def .arg_storage_types :
166
+ fixed_storage_type = self .suite_def .arg_storage_types ["out" ]
167
+ fixed_memory_layout = None
168
+ if "out" in self .suite_def .arg_memory_layouts :
169
+ fixed_memory_layout = self .suite_def .arg_memory_layouts ["out" ]
170
+
151
171
if ret_type == AT_TENSOR :
152
172
self .refs ["out" ] = ValueRef (
153
- name = "out_ref" , src_cpp_name = "out" , src_cpp_type = ret_type , is_out = True
173
+ name = "out_ref" ,
174
+ src_cpp_name = "out" ,
175
+ src_cpp_type = ret_type ,
176
+ is_out = True ,
177
+ fixed_storage_type = fixed_storage_type ,
178
+ fixed_memory_layout = fixed_memory_layout ,
154
179
)
155
180
elif ret_type == TWO_TENSOR_TUPLE :
156
181
self .refs ["out" ] = [
@@ -159,12 +184,24 @@ def __init__(
159
184
src_cpp_name = "std::get<0>(out)" ,
160
185
src_cpp_type = "at::Tensor" ,
161
186
is_out = True ,
187
+ fixed_storage_type = (
188
+ fixed_storage_type [0 ] if fixed_storage_type else None
189
+ ),
190
+ fixed_memory_layout = (
191
+ fixed_memory_layout [0 ] if fixed_memory_layout else None
192
+ ),
162
193
),
163
194
ValueRef (
164
195
name = "out_ref_second" ,
165
196
src_cpp_name = "std::get<1>(out)" ,
166
197
src_cpp_type = "at::Tensor" ,
167
198
is_out = True ,
199
+ fixed_storage_type = (
200
+ fixed_storage_type [1 ] if fixed_storage_type else None
201
+ ),
202
+ fixed_memory_layout = (
203
+ fixed_memory_layout [1 ] if fixed_memory_layout else None
204
+ ),
168
205
),
169
206
ValueRef (
170
207
name = "out_ref" ,
@@ -180,18 +217,36 @@ def __init__(
180
217
src_cpp_name = "std::get<0>(out)" ,
181
218
src_cpp_type = "at::Tensor" ,
182
219
is_out = True ,
220
+ fixed_storage_type = (
221
+ fixed_storage_type [0 ] if fixed_storage_type else None
222
+ ),
223
+ fixed_memory_layout = (
224
+ fixed_memory_layout [0 ] if fixed_memory_layout else None
225
+ ),
183
226
),
184
227
ValueRef (
185
228
name = "out_ref_second" ,
186
229
src_cpp_name = "std::get<1>(out)" ,
187
230
src_cpp_type = "at::Tensor" ,
188
231
is_out = True ,
232
+ fixed_storage_type = (
233
+ fixed_storage_type [1 ] if fixed_storage_type else None
234
+ ),
235
+ fixed_memory_layout = (
236
+ fixed_memory_layout [1 ] if fixed_memory_layout else None
237
+ ),
189
238
),
190
239
ValueRef (
191
240
name = "out_ref_third" ,
192
241
src_cpp_name = "std::get<2>(out)" ,
193
242
src_cpp_type = "at::Tensor" ,
194
243
is_out = True ,
244
+ fixed_storage_type = (
245
+ fixed_storage_type [2 ] if fixed_storage_type else None
246
+ ),
247
+ fixed_memory_layout = (
248
+ fixed_memory_layout [2 ] if fixed_memory_layout else None
249
+ ),
195
250
),
196
251
ValueRef (
197
252
name = "out_ref" ,
@@ -302,7 +357,12 @@ def create_value_for( # noqa: C901
302
357
ret_str += f"{ self .graph } { self .dot } "
303
358
ret_str += "add_input_tensor(" if ref .is_in else "add_tensor("
304
359
ret_str += f"{ ref .src_cpp_name } ->sizes().vec(), "
305
- ret_str += f"from_at_scalartype({ ref .src_cpp_name } ->scalar_type())); \n "
360
+ ret_str += f"from_at_scalartype({ ref .src_cpp_name } ->scalar_type()"
361
+ if ref .fixed_storage_type :
362
+ ret_str += f", { ref .fixed_storage_type } "
363
+ if ref .fixed_memory_layout :
364
+ ret_str += f", { ref .fixed_memory_layout } "
365
+ ret_str += "));\n "
306
366
elif prepack :
307
367
ret_str += f"{ self .graph } { self .dot } "
308
368
ret_str += f"add_tensorref({ ref .src_cpp_name } ->sizes().vec(), "
@@ -385,7 +445,12 @@ def create_value_for( # noqa: C901
385
445
elif ref .src_cpp_type == AT_TENSOR and not prepack :
386
446
ret_str += "add_input_tensor(" if ref .is_in else "add_tensor("
387
447
ret_str += f"{ ref .src_cpp_name } .sizes().vec(), "
388
- ret_str += f"from_at_scalartype({ ref .src_cpp_name } .scalar_type())); \n "
448
+ ret_str += f"from_at_scalartype({ ref .src_cpp_name } .scalar_type())"
449
+ if ref .fixed_storage_type :
450
+ ret_str += f", { ref .fixed_storage_type } "
451
+ if ref .fixed_memory_layout :
452
+ ret_str += f", { ref .fixed_memory_layout } "
453
+ ret_str += ");\n "
389
454
elif ref .src_cpp_type == AT_TENSOR and prepack :
390
455
ret_str += f"add_tensorref({ ref .src_cpp_name } .sizes().vec(), "
391
456
ret_str += f"from_at_scalartype({ ref .src_cpp_name } .scalar_type()), "
0 commit comments