Skip to content

Commit 5e1f13b

Browse files
authored
add test_set_value_by_flags_and_idx.py (#4186)
1 parent c5671d7 commit 5e1f13b

File tree

1 file changed

+218
-0
lines changed

1 file changed

+218
-0
lines changed
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
import paddle
19+
20+
from fastdeploy.model_executor.ops.gpu import set_value_by_flags_and_idx
21+
22+
23+
def set_value_by_flags_and_idx_numpy(
24+
pre_ids_all, input_ids, seq_lens_this_time, seq_lens_encoder, seq_lens_decoder, step_idx, stop_flags
25+
):
26+
"""Numpy reference implementation"""
27+
result = pre_ids_all.copy()
28+
bs = seq_lens_this_time.shape[0]
29+
for i in range(bs):
30+
if stop_flags[i]:
31+
continue
32+
seq_len_enc = seq_lens_encoder[i]
33+
seq_len_dec = seq_lens_decoder[i]
34+
current_step_idx = step_idx[i]
35+
if seq_len_enc == 0 and seq_len_dec == 0:
36+
continue
37+
if current_step_idx >= 0:
38+
if seq_len_enc > 0:
39+
token_idx = seq_len_enc - 1
40+
token_to_assign = input_ids[i, token_idx]
41+
else:
42+
token_to_assign = input_ids[i, 0]
43+
result[i, current_step_idx] = token_to_assign
44+
return result
45+
46+
47+
class TestSetValueByFlagsAndIdxRandom(unittest.TestCase):
48+
"""Random case testing"""
49+
50+
def setUp(self):
51+
paddle.seed(42)
52+
np.random.seed(42)
53+
batch_size = 10
54+
max_length = 10
55+
max_input_length = 15
56+
57+
# Generate random inputs
58+
self.pre_ids_all_np = np.random.randint(0, 1000, size=(batch_size, max_length), dtype="int64")
59+
self.input_ids_np = np.random.randint(0, 1000, size=(batch_size, max_input_length), dtype="int64")
60+
self.seq_lens_this_time_np = np.random.randint(0, max_input_length, size=(batch_size,), dtype="int32")
61+
self.seq_lens_encoder_np = np.random.randint(0, max_input_length, size=(batch_size,), dtype="int32")
62+
self.seq_lens_decoder_np = np.random.randint(0, max_input_length, size=(batch_size,), dtype="int32")
63+
self.step_idx_np = np.random.randint(0, max_length, size=(batch_size,), dtype="int64")
64+
self.stop_flags_np = np.random.choice([True, False], size=(batch_size,), p=[0.1, 0.9])
65+
66+
def test_set_value_by_flags_and_idx(self):
67+
# NumPy baseline
68+
numpy_out = set_value_by_flags_and_idx_numpy(
69+
self.pre_ids_all_np,
70+
self.input_ids_np,
71+
self.seq_lens_this_time_np,
72+
self.seq_lens_encoder_np,
73+
self.seq_lens_decoder_np,
74+
self.step_idx_np,
75+
self.stop_flags_np,
76+
)
77+
# custom op
78+
pre_ids_all = paddle.to_tensor(self.pre_ids_all_np)
79+
set_value_by_flags_and_idx(
80+
pre_ids_all,
81+
paddle.to_tensor(self.input_ids_np),
82+
paddle.to_tensor(self.seq_lens_this_time_np),
83+
paddle.to_tensor(self.seq_lens_encoder_np),
84+
paddle.to_tensor(self.seq_lens_decoder_np),
85+
paddle.to_tensor(self.step_idx_np),
86+
paddle.to_tensor(self.stop_flags_np),
87+
)
88+
# Ensure outputs match exactly
89+
np.testing.assert_array_equal(numpy_out, pre_ids_all.numpy())
90+
91+
92+
class TestSetValueByFlagsAndIdxCornerCases(unittest.TestCase):
93+
"""Cover corner cases"""
94+
95+
def test_encoder_update(self):
96+
# encoder case: seq_lens_encoder > 0, use last token
97+
pre_ids_all = np.zeros((1, 5), dtype="int64")
98+
input_ids = np.array([[11, 12, 13]], dtype="int64")
99+
seq_lens_this_time = np.array([3], dtype="int32")
100+
seq_lens_encoder = np.array([3], dtype="int32")
101+
seq_lens_decoder = np.array([0], dtype="int32")
102+
step_idx = np.array([0], dtype="int64")
103+
stop_flags = np.array([False], dtype="bool")
104+
105+
expected = set_value_by_flags_and_idx_numpy(
106+
pre_ids_all, input_ids, seq_lens_this_time, seq_lens_encoder, seq_lens_decoder, step_idx, stop_flags
107+
)
108+
pre_ids_all_tensor = paddle.to_tensor(pre_ids_all)
109+
set_value_by_flags_and_idx(
110+
pre_ids_all_tensor,
111+
paddle.to_tensor(input_ids),
112+
paddle.to_tensor(seq_lens_this_time),
113+
paddle.to_tensor(seq_lens_encoder),
114+
paddle.to_tensor(seq_lens_decoder),
115+
paddle.to_tensor(step_idx),
116+
paddle.to_tensor(stop_flags),
117+
)
118+
np.testing.assert_array_equal(expected, pre_ids_all_tensor.numpy())
119+
120+
def test_decoder_update(self):
121+
# decoder case: seq_lens_encoder=0, use first token
122+
pre_ids_all = np.zeros((1, 4), dtype="int64")
123+
input_ids = np.array([[101, 102]], dtype="int64")
124+
seq_lens_this_time = np.array([2], dtype="int32")
125+
seq_lens_encoder = np.array([0], dtype="int32")
126+
seq_lens_decoder = np.array([2], dtype="int32")
127+
step_idx = np.array([2], dtype="int64")
128+
stop_flags = np.array([False], dtype="bool")
129+
130+
expected = set_value_by_flags_and_idx_numpy(
131+
pre_ids_all, input_ids, seq_lens_this_time, seq_lens_encoder, seq_lens_decoder, step_idx, stop_flags
132+
)
133+
pre_ids_all_tensor = paddle.to_tensor(pre_ids_all)
134+
set_value_by_flags_and_idx(
135+
pre_ids_all_tensor,
136+
paddle.to_tensor(input_ids),
137+
paddle.to_tensor(seq_lens_this_time),
138+
paddle.to_tensor(seq_lens_encoder),
139+
paddle.to_tensor(seq_lens_decoder),
140+
paddle.to_tensor(step_idx),
141+
paddle.to_tensor(stop_flags),
142+
)
143+
np.testing.assert_array_equal(expected, pre_ids_all_tensor.numpy())
144+
145+
def test_stop_flag(self):
146+
# stop_flags=True, no update
147+
pre_ids_all = np.zeros((1, 3), dtype="int64")
148+
input_ids = np.array([[5, 6, 7]], dtype="int64")
149+
seq_lens_this_time = np.array([3], dtype="int32")
150+
seq_lens_encoder = np.array([3], dtype="int32")
151+
seq_lens_decoder = np.array([0], dtype="int32")
152+
step_idx = np.array([1], dtype="int64")
153+
stop_flags = np.array([True], dtype="bool")
154+
155+
expected = set_value_by_flags_and_idx_numpy(
156+
pre_ids_all, input_ids, seq_lens_this_time, seq_lens_encoder, seq_lens_decoder, step_idx, stop_flags
157+
)
158+
pre_ids_all_tensor = paddle.to_tensor(pre_ids_all)
159+
set_value_by_flags_and_idx(
160+
pre_ids_all_tensor,
161+
paddle.to_tensor(input_ids),
162+
paddle.to_tensor(seq_lens_this_time),
163+
paddle.to_tensor(seq_lens_encoder),
164+
paddle.to_tensor(seq_lens_decoder),
165+
paddle.to_tensor(step_idx),
166+
paddle.to_tensor(stop_flags),
167+
)
168+
np.testing.assert_array_equal(expected, pre_ids_all_tensor.numpy())
169+
170+
def test_skip_when_both_len_zero(self):
171+
# seq_lens_encoder=0 and seq_lens_decoder=0, skip
172+
pre_ids_all = np.zeros((1, 3), dtype="int64")
173+
input_ids = np.array([[8, 9, 10]], dtype="int64")
174+
seq_lens_this_time = np.array([3], dtype="int32")
175+
seq_lens_encoder = np.array([0], dtype="int32")
176+
seq_lens_decoder = np.array([0], dtype="int32")
177+
step_idx = np.array([0], dtype="int64")
178+
stop_flags = np.array([False], dtype="bool")
179+
180+
expected = pre_ids_all.copy()
181+
pre_ids_all_tensor = paddle.to_tensor(pre_ids_all)
182+
set_value_by_flags_and_idx(
183+
pre_ids_all_tensor,
184+
paddle.to_tensor(input_ids),
185+
paddle.to_tensor(seq_lens_this_time),
186+
paddle.to_tensor(seq_lens_encoder),
187+
paddle.to_tensor(seq_lens_decoder),
188+
paddle.to_tensor(step_idx),
189+
paddle.to_tensor(stop_flags),
190+
)
191+
np.testing.assert_array_equal(expected, pre_ids_all_tensor.numpy())
192+
193+
def test_step_idx_negative(self):
194+
# step_idx < 0, skip
195+
pre_ids_all = np.zeros((1, 3), dtype="int64")
196+
input_ids = np.array([[42, 43, 44]], dtype="int64")
197+
seq_lens_this_time = np.array([3], dtype="int32")
198+
seq_lens_encoder = np.array([2], dtype="int32")
199+
seq_lens_decoder = np.array([1], dtype="int32")
200+
step_idx = np.array([-1], dtype="int64")
201+
stop_flags = np.array([False], dtype="bool")
202+
203+
expected = pre_ids_all.copy()
204+
pre_ids_all_tensor = paddle.to_tensor(pre_ids_all)
205+
set_value_by_flags_and_idx(
206+
pre_ids_all_tensor,
207+
paddle.to_tensor(input_ids),
208+
paddle.to_tensor(seq_lens_this_time),
209+
paddle.to_tensor(seq_lens_encoder),
210+
paddle.to_tensor(seq_lens_decoder),
211+
paddle.to_tensor(step_idx),
212+
paddle.to_tensor(stop_flags),
213+
)
214+
np.testing.assert_array_equal(expected, pre_ids_all_tensor.numpy())
215+
216+
217+
if __name__ == "__main__":
218+
unittest.main()

0 commit comments

Comments
 (0)