
Commit b392e6a

Add torch.compile for all small ops (#432)
1 parent 0991003 commit b392e6a

5 files changed: +13, -0 lines changed
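
The change is mechanical: each small, elementwise forward pass gains a @torch.compile(dynamic=True) decorator so PyTorch can fuse its chain of pointwise kernels. A minimal, self-contained sketch of the pattern (the standalone function name and the shapes are illustrative, not from the repo):

    import torch

    # Sketch of the pattern this commit applies: wrap a small elementwise op
    # in torch.compile so its pointwise kernels can be fused into one.
    @torch.compile(dynamic=True)
    def quick_gelu(x: torch.Tensor) -> torch.Tensor:
        # Same math as QuickGELU.forward_native in the activation.py diff below.
        return x * torch.sigmoid(1.702 * x)

    # dynamic=True traces with symbolic shapes, so a change in batch or
    # sequence length reuses the compiled graph instead of recompiling.
    y1 = quick_gelu(torch.randn(4, 128))
    y2 = quick_gelu(torch.randn(2, 256))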

fastvideo/v1/layers/activation.py

Lines changed: 4 additions & 0 deletions
@@ -25,6 +25,7 @@ class SiluAndMul(CustomOp):
     def __init__(self) -> None:
         super().__init__()
 
+    @torch.compile(dynamic=True)
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
@@ -48,6 +49,7 @@ def __init__(self, approximate: str = "none"):
         if approximate not in ("none", "tanh"):
             raise ValueError(f"Unknown approximate mode: {approximate}")
 
+    @torch.compile(dynamic=True)
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
@@ -63,6 +65,7 @@ class NewGELU(CustomOp):
     def __init__(self):
         super().__init__()
 
+    @torch.compile(dynamic=True)
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         c = math.sqrt(2.0 / math.pi)
@@ -76,6 +79,7 @@ class QuickGELU(CustomOp):
     def __init__(self):
         super().__init__()
 
+    @torch.compile(dynamic=True)
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)
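
The diff only shows the first line of most bodies; for SiluAndMul the usual native implementation splits the last dimension in half and uses the SiLU of the first half to gate the second. A hedged sketch of that computation under the new decorator (the body beyond `d = x.shape[-1] // 2` is an assumption):

    import torch
    import torch.nn.functional as F

    # Assumed body of SiluAndMul.forward_native: split the last dim, apply
    # SiLU to the first half, and multiply it with the second half.
    @torch.compile(dynamic=True)
    def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    out = silu_and_mul(torch.randn(8, 1024))  # -> shape (8, 512)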

fastvideo/v1/layers/layernorm.py

Lines changed: 4 additions & 0 deletions
@@ -37,6 +37,7 @@ def __init__(
         if self.has_weight:
             self.weight = nn.Parameter(self.weight)
 
+    @torch.compile(dynamic=True)
     def forward_native(
         self,
         x: torch.Tensor,
@@ -89,6 +90,7 @@ class ScaleResidual(nn.Module):
     def __init__(self, prefix: str = ""):
         super().__init__()
 
+    @torch.compile(dynamic=True)
     def forward(self, residual: torch.Tensor, x: torch.Tensor,
                 gate: torch.Tensor) -> torch.Tensor:
         """Apply gated residual connection."""
@@ -128,6 +130,7 @@ def __init__(
         else:
             raise NotImplementedError(f"Norm type {norm_type} not implemented")
 
+    @torch.compile(dynamic=True)
     def forward(self, residual: torch.Tensor, x: torch.Tensor,
                 gate: torch.Tensor, shift: torch.Tensor,
                 scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -178,6 +181,7 @@ def __init__(
         else:
             raise NotImplementedError(f"Norm type {norm_type} not implemented")
 
+    @torch.compile(dynamic=True)
     def forward(self, x: torch.Tensor, shift: torch.Tensor,
                 scale: torch.Tensor) -> torch.Tensor:
         """Apply ln followed by scale and shift in a single fused operation."""

fastvideo/v1/layers/mlp.py

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ def __init__(
             bias=bias,
             params_dtype=dtype)
 
+    @torch.compile(dynamic=True)
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, _ = self.fc_in(x)
         x = self.act(x)
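
Here the whole MLP forward (projection, activation, projection) is compiled as one unit rather than just the activation. A minimal sketch of the same decorator placement, with plain nn.Linear standing in for the repo's linear layers (which also return a second value, hence the `x, _ =` above):

    import torch
    import torch.nn as nn

    class MLP(nn.Module):
        # Illustrative stand-in module; names and sizes are not from the repo.
        def __init__(self, dim: int, hidden: int) -> None:
            super().__init__()
            self.fc_in = nn.Linear(dim, hidden)
            self.act = nn.GELU()
            self.fc_out = nn.Linear(hidden, dim)

        # Decorating forward compiles the projection-activation-projection
        # chain together, mirroring the mlp.py change above.
        @torch.compile(dynamic=True)
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.fc_out(self.act(self.fc_in(x)))

    y = MLP(64, 256)(torch.randn(2, 16, 64))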

fastvideo/v1/layers/rotary_embedding.py

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
     return x.flatten(-2)
 
 
+# @torch.compile(dynamic=True)
 def _apply_rotary_emb(
     x: torch.Tensor,
     cos: torch.Tensor,
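
In rotary_embedding.py the decorator is added commented out for _apply_rotary_emb. For reference, a hedged sketch of the GPT-J-style rotation that the _rotate_gptj helper above suggests (the repo's actual function takes more arguments and may support other layouts):

    import torch

    # Assumed GPT-J-style rotary application: adjacent channel pairs (x1, x2)
    # are rotated by (cos, sin) and re-interleaved, matching x.flatten(-2).
    def apply_rotary_emb_gptj(x: torch.Tensor, cos: torch.Tensor,
                              sin: torch.Tensor) -> torch.Tensor:
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        out = torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
        return out.flatten(-2)

    q = torch.randn(2, 8, 64)                 # (batch, heads, head_dim)
    theta = torch.rand(2, 1, 32) * 6.28       # illustrative angles
    q_rot = apply_rotary_emb_gptj(q, torch.cos(theta), torch.sin(theta))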

fastvideo/v1/layers/visual_embedding.py

Lines changed: 3 additions & 0 deletions
@@ -53,6 +53,7 @@ def __init__(self,
             dtype=dtype)
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
 
+    @torch.compile(dynamic=True)
     def forward(self, x):
         x = self.proj(x)
         if self.flatten:
@@ -98,6 +99,7 @@ def forward(self, t: torch.Tensor) -> torch.Tensor:
         return t_emb
 
 
+@torch.compile(dynamic=True)
 def timestep_embedding(t: torch.Tensor,
                        dim: int,
                        max_period: int = 10000,
@@ -145,6 +147,7 @@ def __init__(
             params_dtype=dtype)
         self.act = get_act_fn(act_layer)
 
+    @torch.compile(dynamic=True)
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.act(x)
         x, _ = self.linear(x)
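
timestep_embedding is the one free function decorated in this file. Its body is not shown in the diff; a hedged sketch of the standard sinusoidal embedding that a function with this signature typically computes:

    import math
    import torch

    # Assumed ADM/DiT-style sinusoidal timestep embedding; the repo's version
    # may order cos/sin differently or handle odd dims and dtypes explicitly.
    @torch.compile(dynamic=True)
    def timestep_embedding(t: torch.Tensor, dim: int,
                           max_period: int = 10000) -> torch.Tensor:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
        ).to(t.device)
        args = t[:, None].float() * freqs[None]
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

    emb = timestep_embedding(torch.tensor([0, 250, 999]), dim=256)  # -> (3, 256)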
