@@ -37,6 +37,7 @@ def __init__(
         if self.has_weight:
             self.weight = nn.Parameter(self.weight)
 
+    @torch.compile(dynamic=True)
     def forward_native(
         self,
         x: torch.Tensor,
@@ -89,6 +90,7 @@ class ScaleResidual(nn.Module):
     def __init__(self, prefix: str = ""):
         super().__init__()
 
+    @torch.compile(dynamic=True)
     def forward(self, residual: torch.Tensor, x: torch.Tensor,
                 gate: torch.Tensor) -> torch.Tensor:
         """Apply gated residual connection."""
@@ -128,6 +130,7 @@ def __init__(
         else:
             raise NotImplementedError(f"Norm type {norm_type} not implemented")
 
+    @torch.compile(dynamic=True)
     def forward(self, residual: torch.Tensor, x: torch.Tensor,
                 gate: torch.Tensor, shift: torch.Tensor,
                 scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -178,6 +181,7 @@ def __init__(
         else:
             raise NotImplementedError(f"Norm type {norm_type} not implemented")
 
+    @torch.compile(dynamic=True)
     def forward(self, x: torch.Tensor, shift: torch.Tensor,
                 scale: torch.Tensor) -> torch.Tensor:
         """Apply ln followed by scale and shift in a single fused operation."""