Skip to content

Commit 0ea8712

Browse files
authored
fix op tests (#3398)
1 parent 2e78311 commit 0ea8712

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

test/operators/test_fused_moe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,8 @@ def split_forward(self, hidden_states):
165165
permute_indices_per_token,
166166
top_k_weights,
167167
top_k_indices,
168-
) = moe_expert_dispatch(hidden_states, scores, None, self.top_k, False, topk_only_mode=True)
168+
expert_idx_per_token,
169+
) = moe_expert_dispatch(hidden_states, scores, None, None, self.top_k, False, topk_only_mode=True)
169170

170171
# Process through experts
171172
ffn_out = moe_expert_ffn(

test/operators/test_rejection_top_p_sampling.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,11 @@ def setUp(self):
3535
def test_top_p_sampling_reject_case1(self):
3636
"""Test with fixed top_p=0.8 and different random seeds"""
3737
top_p_paddle = paddle.full((self.batch_size,), 0.8)
38+
top_k_paddle = paddle.full((self.batch_size,), 20).cast("int64")
3839

3940
# Test with different seeds
4041
for seed in [1024, 2033, 2033]:
41-
samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, seed)
42+
samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, top_k_paddle, seed)
4243
self._validate_samples(samples)
4344

4445
# Basic validation
@@ -48,13 +49,12 @@ def test_top_p_sampling_reject_case1(self):
4849
def test_top_p_sampling_reject_case2(self):
4950
"""Test with varying top_p values across batch"""
5051
top_p_paddle = paddle.uniform(shape=[self.batch_size], min=0.1, max=1.0)
51-
samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, -1)
52-
52+
top_k_paddle = paddle.full((self.batch_size,), 20).cast("int64")
53+
samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, top_k_paddle, -1)
5354
self._validate_samples(samples)
5455

5556
# Additional check that we're getting different results for different top_p
5657
unique_samples = len(paddle.unique(samples))
57-
print(f"Unique samples: {unique_samples}")
5858
self.assertGreater(unique_samples, 1) # Should have some diversity
5959

6060
def _validate_samples(self, samples):

0 commit comments

Comments
 (0)