 import importlib
 import platform
 import unittest
-from typing import Optional

 import pytest
 import torch
-import torch.nn as nn
 import torch_tensorrt as torchtrt
 from torch_tensorrt.dynamo.utils import (
     COSINE_THRESHOLD,
@@ -422,104 +420,6 @@ def test_resnet18_half(ir):
     torch._dynamo.reset()


-@pytest.mark.unit
-def test_cosmos_true_div(ir):
-    class CosmosLearnablePositionalEmbed(torch.nn.Module):
-        def __init__(
-            self,
-            hidden_size: int,
-            max_size: tuple[int, int, int],
-            patch_size: tuple[int, int, int],
-            eps: float = 1e-6,
-        ) -> None:
-            super().__init__()
-
-            self.max_size = [size // patch for size, patch in zip(max_size, patch_size)]
-            self.patch_size = patch_size
-            self.eps = eps
-
-            self.pos_emb_t = nn.Parameter(torch.randn(self.max_size[0], hidden_size))
-            self.pos_emb_h = nn.Parameter(torch.randn(self.max_size[1], hidden_size))
-            self.pos_emb_w = nn.Parameter(torch.randn(self.max_size[2], hidden_size))
-
-        def forward(
-            self,
-            hidden_states: torch.Tensor,
-            num_ranks: Optional[int] = None,
-            rank_id: Optional[torch.Tensor] = None,
-        ) -> torch.Tensor:
-            batch_size, num_channels, num_frames, height, width = hidden_states.shape
-            pe_size = [
-                num_frames // self.patch_size[0],
-                height // self.patch_size[1],
-                width // self.patch_size[2],
-            ]
-            if num_ranks is not None and rank_id is not None:
-                pe_size[0] = pe_size[0] * num_ranks
-
-            # Use expand() instead of repeat() - torch_tensorrt compatible
-            # expand() creates a view without copying data, better for dynamic shapes
-            emb_t = self.pos_emb_t[: pe_size[0]][None, :, None, None, :].expand(
-                batch_size, -1, pe_size[1], pe_size[2], -1
-            )
-            emb_h = self.pos_emb_h[: pe_size[1]][None, None, :, None, :].expand(
-                batch_size, pe_size[0], -1, pe_size[2], -1
-            )
-            emb_w = self.pos_emb_w[: pe_size[2]][None, None, None, :, :].expand(
-                batch_size, pe_size[0], pe_size[1], -1, -1
-            )
-            emb = emb_t + emb_h + emb_w
-            emb = emb.flatten(1, 3)
-
-            norm = torch.linalg.vector_norm(
-                emb, dim=-1, keepdim=True, dtype=torch.float32
-            )
-            alpha = (norm.numel() / emb.numel()) ** 0.5
-            # hidden_size = emb.shape[-1]
-            # alpha = (1.0 / hidden_size) ** 0.5
-            norm = torch.add(self.eps, norm, alpha=alpha)
-            return (emb / norm).type_as(hidden_states)
-
-    with torch.no_grad():
-        hidden_states = torch.randn(1, 16, 16, 88, 160).cuda()
-        model = CosmosLearnablePositionalEmbed(
-            hidden_size=4096,
-            max_size=(128, 240, 240),
-            patch_size=(1, 2, 2),
-        )
-        model.eval().cuda()
-        pyt_output = model(hidden_states)
-        num_latent_frames = torch.export.Dim("num_latent_frames", min=1, max=16)
-
-        ep = torch.export.export(
-            model,
-            args=(hidden_states,),
-            dynamic_shapes=({2: num_latent_frames},),  # Make dimension 2 dynamic
-            strict=False,
-        )
-        trt_model = torchtrt.dynamo.compile(
-            ep,
-            inputs=(hidden_states,),
-            enabled_precisions={torch.bfloat16},
-            use_explicit_typing=False,
-            use_fp32_acc=False,
-            device="cuda:0",
-            disable_tf32=True,
-            use_python_runtime=True,
-            min_block_size=1,
-        )
-        trt_output = trt_model(hidden_states)
-
-        cos_sim = cosine_similarity(pyt_output, trt_output)
-        assertions.assertTrue(
-            cos_sim > COSINE_THRESHOLD,
-            msg=f"Cosmos Learnable Positional Embed TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
-        )
-
-        # Clean up model env
-        torch._dynamo.reset()
-
-
 @pytest.mark.unit
 @unittest.skipIf(
     torchtrt.ENABLED_FEATURES.tensorrt_rtx,