Skip to content

Commit f8220db

Browse files
committed
test: add unit tests for BlockRefinementScheduler
15 tests covering set_timesteps, get_num_transfer_tokens, config save/load round-trips (save_config/from_pretrained, from_config), and step logic (confidence-based commits, threshold behavior, editing, prompt masking, batched inputs, tuple output, output shapes).
1 parent b3f6cb5 commit f8220db

File tree

1 file changed

+284
-0
lines changed

1 file changed

+284
-0
lines changed
Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
import tempfile
2+
import unittest
3+
4+
import torch
5+
6+
from diffusers import BlockRefinementScheduler
7+
8+
9+
class BlockRefinementSchedulerTest(unittest.TestCase):
    """Unit tests for ``BlockRefinementScheduler``.

    The tests exercise timestep setup, the per-step transfer schedule,
    config serialization, and the outputs of ``step()``.
    """

    def get_scheduler(self, **kwargs):
        """Build a scheduler from the default test config.

        Args:
            **kwargs: Config entries overriding the defaults defined below.

        Returns:
            A ``BlockRefinementScheduler`` built from the merged config.
        """
        config = {
            "block_length": 32,
            "num_inference_steps": 8,
            "threshold": 0.95,
            "editing_threshold": None,
            "minimal_topk": 1,
        }
        config.update(kwargs)
        return BlockRefinementScheduler(**config)

    def test_set_timesteps(self):
        """set_timesteps(8) should expose 8 timesteps counting down from 7 to 0."""
        scheduler = self.get_scheduler()
        scheduler.set_timesteps(8)
        self.assertEqual(scheduler.num_inference_steps, 8)
        self.assertEqual(len(scheduler.timesteps), 8)
        # Timesteps should count down
        self.assertEqual(scheduler.timesteps[0].item(), 7)
        self.assertEqual(scheduler.timesteps[-1].item(), 0)

    def test_set_timesteps_invalid(self):
        """set_timesteps(0) should raise ValueError."""
        scheduler = self.get_scheduler()
        with self.assertRaises(ValueError):
            scheduler.set_timesteps(0)

    def test_get_num_transfer_tokens_even(self):
        """An evenly divisible block is split into equal per-step counts."""
        scheduler = self.get_scheduler()
        schedule = scheduler.get_num_transfer_tokens(block_length=32, num_inference_steps=8)
        self.assertEqual(schedule.sum().item(), 32)
        self.assertEqual(len(schedule), 8)
        # 32 / 8 = 4 each, no remainder
        self.assertTrue((schedule == 4).all().item())

    def test_get_num_transfer_tokens_remainder(self):
        """A remainder is assigned to the earliest step(s) of the schedule."""
        scheduler = self.get_scheduler()
        schedule = scheduler.get_num_transfer_tokens(block_length=10, num_inference_steps=3)
        self.assertEqual(schedule.sum().item(), 10)
        self.assertEqual(len(schedule), 3)
        # 10 / 3 = 3 base, 1 remainder -> [4, 3, 3]
        self.assertEqual(schedule[0].item(), 4)
        self.assertEqual(schedule[1].item(), 3)
        self.assertEqual(schedule[2].item(), 3)

    def test_transfer_schedule_created_on_set_timesteps(self):
        """set_timesteps should populate a transfer schedule summing to block_length."""
        scheduler = self.get_scheduler(block_length=16)
        scheduler.set_timesteps(4)
        self.assertIsNotNone(scheduler._transfer_schedule)
        self.assertEqual(scheduler._transfer_schedule.sum().item(), 16)

    def test_save_load_config_round_trip(self):
        """Config values should survive save_config followed by from_pretrained."""
        scheduler = self.get_scheduler(block_length=64, threshold=0.8, editing_threshold=0.5, minimal_topk=2)
        with tempfile.TemporaryDirectory() as tmpdir:
            scheduler.save_config(tmpdir)
            loaded = BlockRefinementScheduler.from_pretrained(tmpdir)

        self.assertEqual(loaded.config.block_length, 64)
        self.assertEqual(loaded.config.threshold, 0.8)
        self.assertEqual(loaded.config.editing_threshold, 0.5)
        self.assertEqual(loaded.config.minimal_topk, 2)

    def test_from_config(self):
        """from_config should carry over values from an existing scheduler's config."""
        scheduler = self.get_scheduler(block_length=16, threshold=0.7)
        new_scheduler = BlockRefinementScheduler.from_config(scheduler.config)
        self.assertEqual(new_scheduler.config.block_length, 16)
        self.assertEqual(new_scheduler.config.threshold, 0.7)

    def test_step_commits_tokens(self):
        """Verify that step() commits mask tokens based on confidence."""
        scheduler = self.get_scheduler(block_length=8)
        scheduler.set_timesteps(2)

        batch_size, block_length = 1, 8
        mask_id = 99

        # All positions are masked
        sample = torch.full((batch_size, block_length), mask_id, dtype=torch.long)
        sampled_tokens = torch.arange(block_length, dtype=torch.long).unsqueeze(0)
        # Confidence decreasing: first tokens are most confident
        sampled_probs = torch.tensor([[0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2]])

        out = scheduler.step(
            sampled_tokens=sampled_tokens,
            sampled_probs=sampled_probs,
            timestep=0,
            sample=sample,
            mask_token_id=mask_id,
            threshold=0.95,
            return_dict=True,
        )

        # With 8 tokens and 2 steps, first step should commit 4 tokens
        committed = out.transfer_index[0].sum().item()
        self.assertEqual(committed, 4)
        # The 4 most confident (highest prob) positions, 0..3, should be committed
        self.assertTrue(out.transfer_index[0, :4].all().item())

    def test_step_threshold_commits_all_above(self):
        """When enough tokens exceed threshold, commit all of them (not just num_to_transfer)."""
        scheduler = self.get_scheduler(block_length=8)
        scheduler.set_timesteps(4)  # 2 tokens per step

        batch_size, block_length = 1, 8
        mask_id = 99

        sample = torch.full((batch_size, block_length), mask_id, dtype=torch.long)
        sampled_tokens = torch.arange(block_length, dtype=torch.long).unsqueeze(0)
        # 5 tokens above threshold of 0.5
        sampled_probs = torch.tensor([[0.9, 0.8, 0.7, 0.6, 0.55, 0.1, 0.1, 0.1]])

        out = scheduler.step(
            sampled_tokens=sampled_tokens,
            sampled_probs=sampled_probs,
            timestep=0,
            sample=sample,
            mask_token_id=mask_id,
            threshold=0.5,
            return_dict=True,
        )

        # All 5 above threshold should be committed (more than num_to_transfer=2)
        committed = out.transfer_index[0].sum().item()
        self.assertEqual(committed, 5)

    def test_step_no_editing_by_default(self):
        """Without editing_threshold, no non-mask tokens should be changed."""
        scheduler = self.get_scheduler(block_length=4)
        scheduler.set_timesteps(2)

        sample = torch.tensor([[10, 20, 99, 99]], dtype=torch.long)
        sampled_tokens = torch.tensor([[50, 60, 70, 80]], dtype=torch.long)
        sampled_probs = torch.tensor([[0.99, 0.99, 0.99, 0.99]])

        out = scheduler.step(
            sampled_tokens=sampled_tokens,
            sampled_probs=sampled_probs,
            timestep=0,
            sample=sample,
            mask_token_id=99,
            editing_threshold=None,
            return_dict=True,
        )

        # Non-mask positions should not be edited
        self.assertFalse(out.editing_transfer_index.any().item())
        # Only mask positions should be committed
        self.assertFalse(out.transfer_index[0, 0].item())
        self.assertFalse(out.transfer_index[0, 1].item())

    def test_step_editing_replaces_tokens(self):
        """With editing_threshold, non-mask tokens with high confidence and different prediction get replaced."""
        scheduler = self.get_scheduler(block_length=4)
        scheduler.set_timesteps(2)

        sample = torch.tensor([[10, 20, 99, 99]], dtype=torch.long)
        # Token 0: model predicts 50 (different from 10) with high confidence
        # Token 1: model predicts 20 (same as current) — should NOT edit
        sampled_tokens = torch.tensor([[50, 20, 70, 80]], dtype=torch.long)
        sampled_probs = torch.tensor([[0.99, 0.99, 0.5, 0.5]])

        out = scheduler.step(
            sampled_tokens=sampled_tokens,
            sampled_probs=sampled_probs,
            timestep=0,
            sample=sample,
            mask_token_id=99,
            editing_threshold=0.8,
            return_dict=True,
        )

        # Token 0 should be edited (different prediction, high confidence)
        self.assertTrue(out.editing_transfer_index[0, 0].item())
        # Token 1 should NOT be edited (same prediction)
        self.assertFalse(out.editing_transfer_index[0, 1].item())
        # prev_sample should reflect the edit
        self.assertEqual(out.prev_sample[0, 0].item(), 50)

    def test_step_prompt_mask_prevents_editing(self):
        """Prompt positions should never be edited even with editing enabled."""
        scheduler = self.get_scheduler(block_length=4)
        scheduler.set_timesteps(2)

        sample = torch.tensor([[10, 20, 99, 99]], dtype=torch.long)
        sampled_tokens = torch.tensor([[50, 60, 70, 80]], dtype=torch.long)
        sampled_probs = torch.tensor([[0.99, 0.99, 0.99, 0.99]])
        prompt_mask = torch.tensor([True, True, False, False])

        out = scheduler.step(
            sampled_tokens=sampled_tokens,
            sampled_probs=sampled_probs,
            timestep=0,
            sample=sample,
            mask_token_id=99,
            editing_threshold=0.5,
            prompt_mask=prompt_mask,
            return_dict=True,
        )

        # Prompt positions should not be edited
        self.assertFalse(out.editing_transfer_index[0, 0].item())
        self.assertFalse(out.editing_transfer_index[0, 1].item())

    def test_step_return_tuple(self):
        """Verify tuple output when return_dict=False."""
        scheduler = self.get_scheduler(block_length=4)
        scheduler.set_timesteps(2)

        sample = torch.full((1, 4), 99, dtype=torch.long)
        sampled_tokens = torch.arange(4, dtype=torch.long).unsqueeze(0)
        sampled_probs = torch.ones(1, 4)

        result = scheduler.step(
            sampled_tokens=sampled_tokens,
            sampled_probs=sampled_probs,
            timestep=0,
            sample=sample,
            mask_token_id=99,
            return_dict=False,
        )

        self.assertIsInstance(result, tuple)
        self.assertEqual(len(result), 5)

    def test_step_batched(self):
        """Verify step works with batch_size > 1."""
        scheduler = self.get_scheduler(block_length=4)
        scheduler.set_timesteps(2)

        batch_size = 3
        mask_id = 99
        sample = torch.full((batch_size, 4), mask_id, dtype=torch.long)
        sampled_tokens = torch.arange(4, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)
        # Seeded local generator keeps the random confidences reproducible
        # without touching the global RNG state shared with other tests.
        generator = torch.Generator().manual_seed(0)
        sampled_probs = torch.rand(batch_size, 4, generator=generator)

        out = scheduler.step(
            sampled_tokens=sampled_tokens,
            sampled_probs=sampled_probs,
            timestep=0,
            sample=sample,
            mask_token_id=mask_id,
            return_dict=True,
        )

        self.assertEqual(out.prev_sample.shape, (batch_size, 4))
        self.assertEqual(out.transfer_index.shape, (batch_size, 4))

    def test_step_output_shape_matches_input(self):
        """All output tensors should match the input sample shape."""
        scheduler = self.get_scheduler(block_length=8)
        scheduler.set_timesteps(4)

        sample = torch.full((2, 8), 99, dtype=torch.long)
        sampled_tokens = torch.zeros_like(sample)
        # Seeded local generator keeps the random confidences reproducible
        # without touching the global RNG state shared with other tests.
        generator = torch.Generator().manual_seed(0)
        sampled_probs = torch.rand(2, 8, generator=generator)

        out = scheduler.step(
            sampled_tokens=sampled_tokens,
            sampled_probs=sampled_probs,
            timestep=0,
            sample=sample,
            mask_token_id=99,
            return_dict=True,
        )

        self.assertEqual(out.prev_sample.shape, sample.shape)
        self.assertEqual(out.transfer_index.shape, sample.shape)
        self.assertEqual(out.editing_transfer_index.shape, sample.shape)
        self.assertEqual(out.sampled_tokens.shape, sample.shape)
        self.assertEqual(out.sampled_probs.shape, sample.shape)
281+
282+
283+
# Run the test suite only when this file is executed directly, not on import.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)