Commit 650c03f
Message: test
Parent: b9e9965

4 files changed: 36 additions, 16 deletions

tests/pipelines/cosmos/test_cosmos.py

Lines changed: 9 additions & 4 deletions
@@ -153,11 +153,16 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         self.assertEqual(generated_video.shape, (9, 3, 32, 32))
-        expected_video = torch.randn(9, 3, 32, 32)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4525, 0.452, 0.4485, 0.4534, 0.4524, 0.4529, 0.454, 0.453, 0.5127, 0.5326, 0.5204, 0.5253, 0.5439, 0.5424, 0.5133, 0.5078])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        print("txt2video:", [round(x, 4) for x in generated_slice.tolist()])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

     def test_callback_inputs(self):
         sig = inspect.signature(self.pipeline_class.__call__)
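The change is the same in all four test files: the placeholder check against torch.randn (which, at an atol of 1e10, could never fail) is replaced by a 16-value fingerprint of the flattened output, compared against hard-coded values at atol=1e-3. A minimal, self-contained sketch of that pattern follows; the output_slice helper and the torch.rand stand-in for the pipeline output are illustrative, not part of the diff.

import torch

def output_slice(t: torch.Tensor, n: int = 8) -> torch.Tensor:
    # Flatten the tensor and keep a 2*n-value fingerprint:
    # its first n and last n elements.
    flat = t.flatten()
    return torch.cat([flat[:n], flat[-n:]])

# Stand-in for pipe(**inputs).frames[0]; the dummy pipelines are seeded,
# which is what makes a hard-coded expected slice reproducible.
torch.manual_seed(0)
generated_video = torch.rand(9, 3, 32, 32)

# In the tests, expected_slice is hard-coded from a known-good run; here it
# is derived from the same tensor purely to demonstrate the comparison.
expected_slice = output_slice(generated_video)
assert torch.allclose(output_slice(generated_video), expected_slice, atol=1e-3)

# Each test also prints the rounded slice so the hard-coded values can be
# refreshed when the pipeline's output legitimately changes.
print([round(x, 4) for x in output_slice(generated_video).tolist()])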

tests/pipelines/cosmos/test_cosmos2_text2image.py

Lines changed: 9 additions & 4 deletions
@@ -140,11 +140,16 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         image = pipe(**inputs).images
         generated_image = image[0]
-
         self.assertEqual(generated_image.shape, (3, 32, 32))
-        expected_video = torch.randn(3, 32, 32)
-        max_diff = np.abs(generated_image - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4525, 0.452, 0.4485, 0.4534, 0.4524, 0.4529, 0.454, 0.453, 0.5127, 0.5326, 0.5204, 0.5253, 0.5439, 0.5424, 0.5133, 0.5078])
+        # fmt: on
+
+        generated_slice = generated_image.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        print("txt2img:", [round(x, 4) for x in generated_slice.tolist()])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

     def test_callback_inputs(self):
         sig = inspect.signature(self.pipeline_class.__call__)

tests/pipelines/cosmos/test_cosmos2_video2world.py

Lines changed: 9 additions & 4 deletions
@@ -147,11 +147,16 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         self.assertEqual(generated_video.shape, (9, 3, 32, 32))
-        expected_video = torch.randn(9, 3, 32, 32)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4525, 0.452, 0.4485, 0.4534, 0.4524, 0.4529, 0.454, 0.453, 0.5127, 0.5326, 0.5204, 0.5253, 0.5439, 0.5424, 0.5133, 0.5078])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        print("cosmos2video2world:", [round(x, 4) for x in generated_slice.tolist()])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

     def test_components_function(self):
         init_components = self.get_dummy_components()

tests/pipelines/cosmos/test_cosmos_video2world.py

Lines changed: 9 additions & 4 deletions
@@ -159,11 +159,16 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         self.assertEqual(generated_video.shape, (9, 3, 32, 32))
-        expected_video = torch.randn(9, 3, 32, 32)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4525, 0.452, 0.4485, 0.4534, 0.4524, 0.4529, 0.454, 0.453, 0.5127, 0.5326, 0.5204, 0.5253, 0.5439, 0.5424, 0.5133, 0.5078])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        print("vid2world:", [round(x, 4) for x in generated_slice.tolist()])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

     def test_components_function(self):
         init_components = self.get_dummy_components()
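Design note: the four expected slices are identical, which suggests the dummy components and seeds produce the same output across all four pipelines; the distinct print labels (txt2video, txt2img, cosmos2video2world, vid2world) identify which test emitted which slice in the log, so the hard-coded values can be refreshed independently if the pipelines ever diverge.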
