@@ -315,47 +315,58 @@ def __init__(self):
                 super().__init__()
                 self.linear = nn.Linear(3, 2)
 
-            def forward(self, x, scale=1.0):
+            def forward(self, x, *, scale):
                 return self.linear(x) * scale
 
         model = ModuleWithKwargs()
         inputs = (torch.randn(4, 3),)
-        kwargs = {"scale": 2.0}
+        kwargs = {"scale": torch.tensor(2.0)}
+
+        gm = _dynamo_graph_capture_for_export(model)(*inputs, **kwargs)
 
         with ExitStack() as stack:
             # Export joint with descriptors
             joint_with_descriptors = aot_export_joint_with_descriptors(
-                stack, model, inputs, kwargs, decompositions=decomposition_table
+                stack, gm, inputs, kwargs, decompositions=decomposition_table
             )
 
             # Test the exported graph structure
             graph_code = joint_with_descriptors.graph_module.print_readable(
                 print_output=False, expanded_def=True
             )
 
+            # For some reason PYTORCH_TEST_WITH_CROSSREF will add extra spaces.
+            # I tried to fix this in normalize_gm but there are too many files
+            # depending on that behavior..
+            graph_code_str = normalize_gm(graph_code)
+            graph_code_str = "\n".join(
+                [line for line in graph_code_str.split("\n") if len(line.rstrip()) > 0]
+            )
+
             # Expect test on the printed graph
             self.assertExpectedInline(
-                normalize_gm(graph_code),
+                graph_code_str,
                 """\
 class inner_f(torch.nn.Module):
     def forward(
         self,
         primals,
         tangents,
     ):
-        primals_1: "f32[2, 3]" # ParamAOTInput(target='linear.weight')
-        primals_2: "f32[2]" # ParamAOTInput(target='linear.bias')
+        primals_1: "f32[2, 3]" # ParamAOTInput(target='L__self___linear.weight')
+        primals_2: "f32[2]" # ParamAOTInput(target='L__self___linear.bias')
         primals_3: "f32[4, 3]" # PlainAOTInput(idx=0)
+        primals_4: "f32[]" # PlainAOTInput(idx=1)
         tangents_1: "f32[4, 2]" # TangentAOTInput(output=PlainAOTOutput(idx=0))
-        primals_1, primals_2, primals_3, primals_4 , tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
+        primals_1, primals_2, primals_3, primals_4, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
         transpose: "f32[3, 2]" = torch.ops.prims.transpose.default(primals_1, [1, 0]); primals_1 = None
         mm: "f32[4, 2]" = torch.ops.aten.mm.default(primals_3, transpose); transpose = None
         mul: "f32[4, 2]" = torch.ops.prims.mul.default(mm, 1.0); mm = None
         mul_1: "f32[2]" = torch.ops.prims.mul.default(primals_2, 1.0); primals_2 = None
         broadcast_in_dim: "f32[4, 2]" = torch.ops.prims.broadcast_in_dim.default(mul_1, [4, 2], [1]); mul_1 = None
         add: "f32[4, 2]" = torch.ops.prims.add.default(mul, broadcast_in_dim); mul = broadcast_in_dim = None
-        mul_2: "f32[4, 2]" = torch.ops.prims.mul.default(add, 2.0); add = None
-        mul_3: "f32[4, 2]" = torch.ops.prims.mul.default(tangents_1, 2.0); tangents_1 = None
+        mul_2: "f32[4, 2]" = torch.ops.prims.mul.default(add, primals_4); add = None
+        mul_3: "f32[4, 2]" = torch.ops.prims.mul.default(tangents_1, primals_4); tangents_1 = primals_4 = None
         transpose_1: "f32[2, 4]" = torch.ops.prims.transpose.default(mul_3, [1, 0])
         mm_1: "f32[2, 3]" = torch.ops.aten.mm.default(transpose_1, primals_3); transpose_1 = primals_3 = None
         transpose_2: "f32[3, 2]" = torch.ops.prims.transpose.default(mm_1, [1, 0]); mm_1 = None
@@ -365,12 +376,11 @@ def forward(
         transpose_3: "f32[2, 3]" = torch.ops.prims.transpose.default(transpose_2, [1, 0]); transpose_2 = None
         return pytree.tree_unflatten([
             mul_2, # PlainAOTOutput(idx=0)
-            transpose_3, # GradAOTOutput(grad_of=ParamAOTInput(target='linear.weight'))
-            as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='linear.bias'))
+            transpose_3, # GradAOTOutput(grad_of=ParamAOTInput(target='L__self___linear.weight'))
+            as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='L__self___linear.bias'))
             None, # None
             None, # None
-        ], self._out_spec)
-""",
+        ], self._out_spec)""",
             )
 
             # Compile the result
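
For reference, a minimal standalone sketch of the flow this test now exercises: capture the module with Dynamo first, so the keyword-only tensor kwarg is lifted into the graph as a real input (the primals_4 placeholder in the expected output above) instead of being baked in as a constant, then export the joint graph. The import paths below are assumptions based on where these helpers currently live, not part of this diff, and may move between versions:

from contextlib import ExitStack

import torch
import torch.nn as nn
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors


class ModuleWithKwargs(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)

    def forward(self, x, *, scale):
        return self.linear(x) * scale


model = ModuleWithKwargs()
inputs = (torch.randn(4, 3),)
# A tensor (not a Python float), so it survives tracing as a graph input.
kwargs = {"scale": torch.tensor(2.0)}

# Step 1 (the change in this diff): Dynamo graph capture resolves the
# kwargs up front and produces a GraphModule.
gm = _dynamo_graph_capture_for_export(model)(*inputs, **kwargs)

# Step 2: export the joint forward/backward graph with descriptors from
# the captured GraphModule rather than the eager module.
with ExitStack() as stack:
    joint = aot_export_joint_with_descriptors(stack, gm, inputs, kwargs)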