@@ -315,58 +315,47 @@ def __init__(self):
                 super().__init__()
                 self.linear = nn.Linear(3, 2)

-            def forward(self, x, *, scale):
+            def forward(self, x, scale=1.0):
                 return self.linear(x) * scale

         model = ModuleWithKwargs()
         inputs = (torch.randn(4, 3),)
-        kwargs = {"scale": torch.tensor(2.0)}
-
-        gm = _dynamo_graph_capture_for_export(model)(*inputs, **kwargs)
+        kwargs = {"scale": 2.0}

         with ExitStack() as stack:
             # Export joint with descriptors
             joint_with_descriptors = aot_export_joint_with_descriptors(
-                stack, gm, inputs, kwargs, decompositions=decomposition_table
+                stack, model, inputs, kwargs, decompositions=decomposition_table
             )

             # Test the exported graph structure
             graph_code = joint_with_descriptors.graph_module.print_readable(
                 print_output=False, expanded_def=True
             )

-            # For some reason PYTORCH_TEST_WITH_CROSSREF will add extra spaces.
-            # I tried to fix this in normalize_gm but there are too many files
-            # depending on that behavior..
-            graph_code_str = normalize_gm(graph_code)
-            graph_code_str = "\n".join(
-                [line for line in graph_code_str.split("\n") if len(line.rstrip()) > 0]
-            )
-
             # Expect test on the printed graph
             self.assertExpectedInline(
-                graph_code_str,
+                normalize_gm(graph_code),
                 """\
 class inner_f(torch.nn.Module):
     def forward(
         self,
         primals,
         tangents,
     ):
-        primals_1: "f32[2, 3]" # ParamAOTInput(target='L__self___linear.weight')
-        primals_2: "f32[2]" # ParamAOTInput(target='L__self___linear.bias')
+        primals_1: "f32[2, 3]" # ParamAOTInput(target='linear.weight')
+        primals_2: "f32[2]" # ParamAOTInput(target='linear.bias')
         primals_3: "f32[4, 3]" # PlainAOTInput(idx=0)
-        primals_4: "f32[]" # PlainAOTInput(idx=1)
         tangents_1: "f32[4, 2]" # TangentAOTInput(output=PlainAOTOutput(idx=0))
-        primals_1, primals_2, primals_3, primals_4, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
+        primals_1, primals_2, primals_3, primals_4, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
         transpose: "f32[3, 2]" = torch.ops.prims.transpose.default(primals_1, [1, 0]); primals_1 = None
         mm: "f32[4, 2]" = torch.ops.aten.mm.default(primals_3, transpose); transpose = None
         mul: "f32[4, 2]" = torch.ops.prims.mul.default(mm, 1.0); mm = None
         mul_1: "f32[2]" = torch.ops.prims.mul.default(primals_2, 1.0); primals_2 = None
         broadcast_in_dim: "f32[4, 2]" = torch.ops.prims.broadcast_in_dim.default(mul_1, [4, 2], [1]); mul_1 = None
         add: "f32[4, 2]" = torch.ops.prims.add.default(mul, broadcast_in_dim); mul = broadcast_in_dim = None
-        mul_2: "f32[4, 2]" = torch.ops.prims.mul.default(add, primals_4); add = None
-        mul_3: "f32[4, 2]" = torch.ops.prims.mul.default(tangents_1, primals_4); tangents_1 = primals_4 = None
+        mul_2: "f32[4, 2]" = torch.ops.prims.mul.default(add, 2.0); add = None
+        mul_3: "f32[4, 2]" = torch.ops.prims.mul.default(tangents_1, 2.0); tangents_1 = None
         transpose_1: "f32[2, 4]" = torch.ops.prims.transpose.default(mul_3, [1, 0])
         mm_1: "f32[2, 3]" = torch.ops.aten.mm.default(transpose_1, primals_3); transpose_1 = primals_3 = None
         transpose_2: "f32[3, 2]" = torch.ops.prims.transpose.default(mm_1, [1, 0]); mm_1 = None
@@ -376,11 +365,12 @@ def forward(
         transpose_3: "f32[2, 3]" = torch.ops.prims.transpose.default(transpose_2, [1, 0]); transpose_2 = None
         return pytree.tree_unflatten([
             mul_2, # PlainAOTOutput(idx=0)
-            transpose_3, # GradAOTOutput(grad_of=ParamAOTInput(target='L__self___linear.weight'))
-            as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='L__self___linear.bias'))
+            transpose_3, # GradAOTOutput(grad_of=ParamAOTInput(target='linear.weight'))
+            as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='linear.bias'))
             None, # None
             None, # None
-        ], self._out_spec)""",
+        ], self._out_spec)
+""",
             )

             # Compile the result
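
For context: the old test traced the module with _dynamo_graph_capture_for_export before exporting, so parameter descriptors carried dynamo-mangled paths like L__self___linear.weight; passing the eager nn.Module directly to aot_export_joint_with_descriptors keeps plain module paths like linear.weight, and switching the scale kwarg from torch.tensor(2.0) to the float 2.0 specializes it into the graph as the constant 2.0 rather than a tensor input (the former primals_4). A minimal sketch of the new call pattern follows, assuming the API is importable from torch._functorch.aot_autograd as in PyTorch's own tests; it is an illustration, not a reproduction of the test.

    from contextlib import ExitStack

    import torch
    import torch.nn as nn

    # Assumed import path; adjust if your PyTorch version exposes it elsewhere.
    from torch._functorch.aot_autograd import aot_export_joint_with_descriptors


    class ModuleWithKwargs(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(3, 2)

        def forward(self, x, scale=1.0):
            return self.linear(x) * scale


    model = ModuleWithKwargs()
    inputs = (torch.randn(4, 3),)
    kwargs = {"scale": 2.0}  # plain float: baked into the graph as 2.0

    with ExitStack() as stack:
        # Eager module passed directly, no dynamo pre-capture step; the
        # decompositions= argument used by the test is omitted for brevity.
        joint = aot_export_joint_with_descriptors(stack, model, inputs, kwargs)
        print(joint.graph_module.print_readable(print_output=False))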