Skip to content

Commit d763cc5

Browse files
committed
Fix t5 encoder tests
1 parent 2746bec commit d763cc5

File tree

2 files changed

+17
-4
lines changed

2 files changed: +17 additions, -4 deletions (lines changed)

fastvideo/models/encoders/t5.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ def __init__(self,
181181

182182
self.qkv_proj = QKVParallelLinear(
183183
self.d_model,
184-
self.d_model // self.total_num_heads,
184+
#self.d_model // self.total_num_heads,
185+
self.key_value_proj_dim,
185186
self.total_num_heads,
186187
self.total_num_kv_heads,
187188
bias=False,
@@ -199,7 +200,8 @@ def __init__(self,
199200
padding_size=self.relative_attention_num_buckets,
200201
quant_config=quant_config)
201202
self.o = RowParallelLinear(
202-
self.d_model,
203+
#self.d_model,
204+
self.total_num_heads * self.key_value_proj_dim,
203205
self.d_model,
204206
bias=False,
205207
quant_config=quant_config,
@@ -298,10 +300,12 @@ def forward(
298300
) -> torch.Tensor:
299301
bs, seq_len, _ = hidden_states.shape
300302
num_seqs = bs
301-
n, c = self.n_heads, self.d_model // self.total_num_heads
303+
#n, c = self.n_heads, self.d_model // self.total_num_heads
304+
n, c = self.n_heads, self.key_value_proj_dim
302305
qkv, _ = self.qkv_proj(hidden_states)
303306
# Projection of 'own' hidden state (self-attention). No GQA here.
304-
q, k, v = qkv.split(self.inner_dim, dim=-1)
307+
#q, k, v = qkv.split(self.inner_dim, dim=-1)
308+
q, k, v = qkv.split(self.qkv_proj.output_sizes, dim=-1)
305309
q = q.reshape(bs, seq_len, n, c)
306310
k = k.reshape(bs, seq_len, n, c)
307311
v = v.reshape(bs, seq_len, n, c)

fastvideo/tests/encoders/test_t5_encoder.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,15 @@ def test_t5_large_encoder():
169169
# Check number of parameters
170170
logger.info("Model1 has %s parameters", len(params1))
171171
logger.info("Model2 has %s parameters", len(params2))
172+
173+
# # Print parameter names for comparison
174+
# logger.info("Model1 parameters:")
175+
# for name in sorted(params1.keys()):
176+
# logger.info(" %s: %s", name, params1[name].shape)
177+
178+
# logger.info("Model2 parameters:")
179+
# for name in sorted(params2.keys()):
180+
# logger.info(" %s: %s", name, params2[name].shape)
172181

173182
weight_diffs = []
174183
# check if embed_tokens are the same

0 commit comments

Comments (0)