@@ -595,33 +595,97 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
         self.assertEqual(counter, 1)
 
     def test_compile_fix_broken_ops(self) -> None:
-        # When passing an input of more than 4 dimensions to Linear,
-        # aten._unsafe_view is used under the hood
-        x = torch.randn([2, 3, 4, 5])
-        model: torch.nn.Linear = torch.nn.Linear(5, 5)
-
-        class Foo(torch.nn.Module):
-            def __init__(self):
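+        # The exported graph for the model below contains aten.squeeze.dims
+        # and aten.select.int nodes; to_edge()'s
+        # ReplaceBrokenOpsWithFunctionalOpsPass is expected to replace them,
+        # and the test asserts that none survive in the edge graph.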
+        class ExportableLoop(nn.Module):
+            def __init__(self, hidden_size, out_channels):
                 super().__init__()
-                self.model = model
-
-            def forward(self, inp: torch.Tensor) -> torch.Tensor:
-                return self.model(inp)
-
-        f = Foo()
+                self.hidden_size = hidden_size
+                self.B = nn.Parameter(torch.randn(hidden_size, 1))  # (H, in_channels)
+                self.C = nn.Parameter(
+                    torch.randn(out_channels, hidden_size)
+                )  # (C_out, H)
+                A = torch.randn(2, hidden_size)
+                self.A_real = nn.Parameter(A[0].clone())
+                self.A_imag = nn.Parameter(A[1].clone())
+
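+            # One recurrence step with a diagonal complex transition: writing
+            # h = hr + j*hi and A = A_real + j*A_imag, this computes
+            # h' = A * h + x_t elementwise (the input enters the real part).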
+            def update_state(self, h, x_t):
+                # h: [B, 2, H], x_t: [B, H]
+                hr, hi = h[:, 0, :], h[:, 1, :]  # [B, H]
+                hrn = hr * self.A_real - hi * self.A_imag + x_t  # [B, H]
+                hin = hi * self.A_real + hr * self.A_imag  # [B, H]
+                hn = torch.stack([hrn, hin], dim=1)  # [B, 2, H]
+                return hn, hrn
+
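+            # Forward: project the input with B, run the recurrence over the
+            # time axis via while_loop, then project the accumulated real
+            # states with C.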
+            def forward(self, u):
+                # u: [B, 1, T]
+                x = torch.matmul(self.B, u)  # (B, H, T)
+                B, H, T = x.shape
+
+                h = torch.zeros(B, 2, H, device=x.device, dtype=x.dtype)  # [B, 2, H]
+                h_accum = torch.zeros(
+                    B, H, T, device=x.device, dtype=x.dtype
+                )  # [B, H, T]
+                i = torch.tensor(0, device=x.device, dtype=torch.int64)
+                one = torch.tensor(1, device=x.device, dtype=torch.int64)
+
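+                # cond/body for torch._higher_order_ops.while_loop: both take
+                # the carried state (i, h, h_accum); cond returns a scalar
+                # bool tensor, and body returns the next state with the same
+                # shapes and dtypes.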
+                def cond(i, h, h_accum):
+                    return i < T
+
+                def body(i, h, h_accum):
+                    x_t = x.index_select(-1, i.unsqueeze(0)).squeeze(
+                        -1
+                    )  # ✅ safe for export
+                    h, hr = self.update_state(h, x_t)  # h: [B, 2, H], hr: [B, H]
+                    h_accum = h_accum.index_copy(
+                        -1, i.unsqueeze(0), hr.unsqueeze(-1)
+                    )  # [B, H, T]
+                    i_next = i + one
+                    return i_next, h, h_accum
+
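+                # while_loop keeps the loop as a single control-flow node in
+                # the exported graph instead of unrolling it T times.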
+                _, h, h_accum = torch._higher_order_ops.while_loop(
+                    cond, body, (i, h, h_accum)
+                )
+                y = torch.matmul(self.C, h_accum).transpose(0, 1)  # (C_out, B, T)
+                return y
 
-        # ReplaceBrokenOpsWithFunctionalOpsPass is used in to_edge()
+        # Instantiate and export
+        model = ExportableLoop(hidden_size=128, out_channels=10)
+        inp = torch.randn(1, 1, 32)  # (B, in_channels=1, T=32)
+        ep = export(model, (inp,))
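+        # ReplaceBrokenOpsWithFunctionalOpsPass runs as part of to_edge();
+        # _check_ir_validity=False relaxes edge IR validation so the
+        # not-yet-replaced ops are accepted into the pipeline.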
         prog = to_edge(
-            export(f, (x,), strict=True),
+            ep,
             compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
         )
         gm = prog.exported_program().graph_module
         count_after = 0
         for node in gm.graph.nodes:
-            if node.target == torch.ops.aten._unsafe_view.default:
+            if (
+                node.target == torch.ops.aten.squeeze.dims
+                or node.target == torch.ops.aten.select.int
+            ):
                 count_after += 1
         self.assertEqual(count_after, 0)
-        self.assertTrue(torch.allclose(prog.exported_program().module()(x), f(x)))
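+        # The replacement passes must not change numerics.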
+        self.assertTrue(
+            torch.allclose(prog.exported_program().module()(inp), model(inp))
+        )
 
     def test_convert_symb_ops(self) -> None:
         class Foo(torch.nn.Module):