
Commit fe451c3

update
1 parent 0f1a4e0 commit fe451c3


16 files changed: +682 -709 lines changed


tests/conftest.py

Lines changed: 1 addition & 0 deletions

@@ -46,6 +46,7 @@ def pytest_configure(config):
     config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality")
     config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality")
     config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality")
+    config.addinivalue_line("markers", "context_parallel: marks tests for context parallel inference functionality")


 def pytest_addoption(parser):
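
Registering the marker here avoids unknown-marker warnings and lets the context parallel suite be selected or skipped from the command line (for example, pytest -m context_parallel or pytest -m "not context_parallel"). A minimal sketch of applying the marker directly, assuming only standard pytest behavior; the tests added later in this commit presumably pick it up through the @is_context_parallel decorator instead:

import pytest


@pytest.mark.context_parallel
def test_marker_is_applied():
    # Illustrative placeholder only: the real context parallel tests live in
    # ContextParallelTesterMixin and additionally need multiple accelerators.
    assert True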

tests/models/test_modeling_common.py

Lines changed: 12 additions & 12 deletions

@@ -317,9 +317,9 @@ def test_local_files_only_with_sharded_checkpoint(self):
                 repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
             )

-            assert all(
-                torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())
-            ), "Model parameters don't match!"
+            assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), (
+                "Model parameters don't match!"
+            )

             # Remove a shard file
             cached_shard_file = try_to_load_from_cache(
@@ -335,9 +335,9 @@ def test_local_files_only_with_sharded_checkpoint(self):

             # Verify error mentions the missing shard
             error_msg = str(context.exception)
-            assert (
-                cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg
-            ), f"Expected error about missing shard, got: {error_msg}"
+            assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, (
+                f"Expected error about missing shard, got: {error_msg}"
+            )

     @unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners")
     @unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.")
@@ -354,9 +354,9 @@ def test_one_request_upon_cached(self):
             )

             download_requests = [r.method for r in m.request_history]
-            assert (
-                download_requests.count("HEAD") == 3
-            ), "3 HEAD requests one for config, one for model, and one for shard index file."
+            assert download_requests.count("HEAD") == 3, (
+                "3 HEAD requests one for config, one for model, and one for shard index file."
+            )
             assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"

         with requests_mock.mock(real_http=True) as m:
@@ -368,9 +368,9 @@ def test_one_request_upon_cached(self):
             )

             cache_requests = [r.method for r in m.request_history]
-            assert (
-                "HEAD" == cache_requests[0] and len(cache_requests) == 2
-            ), "We should call only `model_info` to check for commit hash and knowing if shard index is present."
+            assert "HEAD" == cache_requests[0] and len(cache_requests) == 2, (
+                "We should call only `model_info` to check for commit hash and knowing if shard index is present."
+            )

     def test_weight_overwrite(self):
         with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
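
These hunks do not change what is asserted; they only rewrap long assert statements so the condition stays on the assert line and the failure message moves into its own parenthesized block, the layout newer auto-formatters such as ruff format tend to produce. An illustrative, self-contained example of the same before/after pattern (not taken from this file):

values = [1, 2, 3]

# Old wrapping: the condition is parenthesized and the message trails the closing paren.
assert (
    len(values) == 3 and all(isinstance(v, int) for v in values)
), "expected three integers"

# New wrapping: the condition stays inline and the message is parenthesized instead.
assert len(values) == 3 and all(isinstance(v, int) for v in values), (
    "expected three integers"
)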

tests/models/testing_utils/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -1,4 +1,4 @@
-from .attention import AttentionTesterMixin
+from .attention import AttentionTesterMixin, ContextParallelTesterMixin
 from .common import ModelTesterMixin
 from .compile import TorchCompileTesterMixin
 from .ip_adapter import IPAdapterTesterMixin
@@ -17,6 +17,7 @@


 __all__ = [
+    "ContextParallelTesterMixin",
     "AttentionTesterMixin",
     "BitsAndBytesTesterMixin",
     "CPUOffloadTesterMixin",

tests/models/testing_utils/attention.py

Lines changed: 91 additions & 7 deletions

@@ -13,15 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
+
 import pytest
 import torch
+import torch.multiprocessing as mp

+from diffusers.models._modeling_parallel import ContextParallelConfig
 from diffusers.models.attention import AttentionModuleMixin
 from diffusers.models.attention_processor import (
     AttnProcessor,
 )

-from ...testing_utils import is_attention, torch_device
+from ...testing_utils import is_attention, is_context_parallel, require_torch_multi_accelerator, torch_device


 @is_attention
@@ -85,9 +89,9 @@ def test_fuse_unfuse_qkv_projections(self):
         output_after_fusion = output_after_fusion.to_tuple()[0]

         # Verify outputs match
-        assert torch.allclose(
-            output_before_fusion, output_after_fusion, atol=self.base_precision
-        ), "Output should not change after fusing projections"
+        assert torch.allclose(output_before_fusion, output_after_fusion, atol=self.base_precision), (
+            "Output should not change after fusing projections"
+        )

         # Unfuse projections
         model.unfuse_qkv_projections()
@@ -106,9 +110,9 @@ def test_fuse_unfuse_qkv_projections(self):
         output_after_unfusion = output_after_unfusion.to_tuple()[0]

         # Verify outputs still match
-        assert torch.allclose(
-            output_before_fusion, output_after_unfusion, atol=self.base_precision
-        ), "Output should match original after unfusing projections"
+        assert torch.allclose(output_before_fusion, output_after_unfusion, atol=self.base_precision), (
+            "Output should match original after unfusing projections"
+        )

     def test_get_set_processor(self):
         init_dict = self.get_init_dict()
@@ -177,3 +181,83 @@ def test_attention_processor_count_mismatch_raises_error(self):
             model.set_attn_processor(wrong_processors)

         assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
+
+
+def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, inputs_dict, result_queue):
+    try:
+        # Setup distributed environment
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "12355"
+
+        torch.distributed.init_process_group(
+            backend="nccl",
+            init_method="env://",
+            world_size=world_size,
+            rank=rank,
+        )
+        torch.cuda.set_device(rank)
+        device = torch.device(f"cuda:{rank}")
+
+        model = model_class(**init_dict)
+        model.to(device)
+        model.eval()
+
+        inputs_on_device = {}
+        for key, value in inputs_dict.items():
+            if isinstance(value, torch.Tensor):
+                inputs_on_device[key] = value.to(device)
+            else:
+                inputs_on_device[key] = value
+
+        cp_config = ContextParallelConfig(**cp_dict)
+        model.enable_parallelism(config=cp_config)
+
+        with torch.no_grad():
+            output = model(**inputs_on_device)
+            if isinstance(output, dict):
+                output = output.to_tuple()[0]
+
+        if rank == 0:
+            result_queue.put(("success", output.shape))
+
+    except Exception as e:
+        if rank == 0:
+            result_queue.put(("error", str(e)))
+    finally:
+        if torch.distributed.is_initialized():
+            torch.distributed.destroy_process_group()
+
+
+@is_context_parallel
+@require_torch_multi_accelerator
+class ContextParallelTesterMixin:
+    base_precision = 1e-3
+
+    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
+    def test_context_parallel_inference(self, cp_type):
+        if not torch.distributed.is_available():
+            pytest.skip("torch.distributed is not available.")
+
+        if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
+            pytest.skip("Context parallel requires at least 2 CUDA devices.")
+
+        if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
+            pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
+
+        world_size = 2
+        init_dict = self.get_init_dict()
+        inputs_dict = self.get_dummy_inputs()
+        cp_dict = {cp_type: world_size}
+
+        ctx = mp.get_context("spawn")
+        result_queue = ctx.Queue()
+
+        mp.spawn(
+            _context_parallel_worker,
+            args=(world_size, self.model_class, init_dict, cp_dict, inputs_dict, result_queue),
+            nprocs=world_size,
+            join=True,
+        )
+
+        status, result = result_queue.get(timeout=60)
+        assert status == "success", f"Context parallel inference failed: {result}"
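
A model-specific suite would presumably opt in by mixing this class into its tests and supplying the hooks the spawned worker relies on: model_class, get_init_dict, and get_dummy_inputs. The sketch below is hypothetical and not part of this commit; MyTransformer2DModel stands in for any diffusers model that defines a _cp_plan, and the config and input shapes are invented for illustration:

import torch

from ..testing_utils import ContextParallelTesterMixin  # relative import path is an assumption


class MyTransformerContextParallelTests(ContextParallelTesterMixin):
    model_class = MyTransformer2DModel  # placeholder: must expose a `_cp_plan`

    def get_init_dict(self):
        # Keep the config tiny so both spawned ranks build the model quickly; keys are illustrative.
        return {"num_layers": 1, "num_attention_heads": 2, "attention_head_dim": 8}

    def get_dummy_inputs(self):
        # Sequence length should generally be divisible by the parallel degree (2 here).
        return {
            "hidden_states": torch.randn(1, 16, 16),
            "timestep": torch.tensor([1.0]),
        }

Running such a suite would then look like pytest tests/models -m context_parallel on a machine with at least two accelerators.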
