|
18 | 18 | import torch |
19 | 19 |
|
20 | 20 | from nemo_rl.distributed.model_utils import ( |
| 21 | + ChunkedDistributedEntropy, |
21 | 22 | ChunkedDistributedGatherLogprob, |
22 | 23 | ChunkedDistributedLogprob, |
23 | 24 | DistributedLogprob, |
@@ -705,14 +706,25 @@ def test_distributed_logprob_forward_and_backward(self): |
705 | 706 | full_logits.grad = None |
706 | 707 |
|
707 | 708 | # Compute distributed gradients |
708 | | - distributed_log_probs = DistributedLogprob.apply( |
709 | | - vocab_parallel_logits, |
710 | | - target, |
711 | | - vocab_start_index, |
712 | | - vocab_end_index, |
713 | | - self.tp_group, |
714 | | - False, # inference_only=False to enable backward |
715 | | - ) |
| 709 | + if chunk_size is not None: |
| 710 | + distributed_log_probs = ChunkedDistributedLogprob.apply( |
| 711 | + vocab_parallel_logits, |
| 712 | + target, |
| 713 | + vocab_start_index, |
| 714 | + vocab_end_index, |
| 715 | + chunk_size, |
| 716 | + self.tp_group, |
| 717 | + False, # inference_only=False to enable backward |
| 718 | + ) |
| 719 | + else: |
| 720 | + distributed_log_probs = DistributedLogprob.apply( |
| 721 | + vocab_parallel_logits, |
| 722 | + target, |
| 723 | + vocab_start_index, |
| 724 | + vocab_end_index, |
| 725 | + self.tp_group, |
| 726 | + False, # inference_only=False to enable backward |
| 727 | + ) |
716 | 728 |
|
717 | 729 | distributed_loss = torch.sum(distributed_log_probs) |
718 | 730 | distributed_loss.backward() |
@@ -956,3 +968,165 @@ def test_distributed_logprob_all_tests( |
956 | 968 |
|
957 | 969 | finally: |
958 | 970 | cluster.shutdown() |
| 971 | + |
| 972 | + |
@ray.remote(num_gpus=1)
class ChunkedDistributedEntropyTestActor:
    """One-GPU Ray actor that checks ChunkedDistributedEntropy against a single-GPU baseline.

    Each rank holds a vocab shard of identical logits (same seed on every rank),
    runs the chunked distributed entropy op over its shard, and compares the
    forward value and (optionally) the shard-local gradient to a full-vocab
    log_softmax baseline computed locally.
    """

    def __init__(self, tp_size: int, chunk_size: int, inference_only: bool):
        # tp_size: tensor-parallel group size; also the number of vocab shards.
        self.tp_size = tp_size
        # chunk_size: sequence-chunk size forwarded to ChunkedDistributedEntropy.
        self.chunk_size = chunk_size
        # inference_only=True skips gradient computation on both paths.
        self.inference_only = inference_only
        # Snapshot of the actor's environment (RANK etc. are read later from os.environ).
        self.env_vars = dict(os.environ)

    def test_chunked_distributed_entropy(self) -> dict:
        """Run the forward/backward comparison; returns max abs diffs for both.

        Returns:
            dict with "forward_max_diff" (float) and "grad_max_diff"
            (float, or None when inference_only).
        Raises:
            AssertionError: if either comparison exceeds rtol/atol of 1e-4.
        """
        torch.distributed.init_process_group(backend="nccl")

        rank = int(os.environ["RANK"])
        # TP group spans ranks [0, tp_size); assumes world size == tp_size here.
        tp_group = torch.distributed.new_group(ranks=list(range(self.tp_size)))

        batch_size = 2
        seq_len = 16
        vocab_size = 256
        # Even vocab split across ranks; assumes vocab_size % tp_size == 0
        # (holds for 256 with tp_size in {1, 2}).
        vocab_part_size = vocab_size // self.tp_size
        vocab_start_index = rank * vocab_part_size
        vocab_end_index = (rank + 1) * vocab_part_size

        # Fixed seed so every rank materializes the SAME full logits tensor;
        # each rank then slices out only its own vocab shard below.
        torch.manual_seed(1337)
        full_logits = torch.randn(batch_size, seq_len, vocab_size, device="cuda")

        # === Baseline: single-GPU entropy ===
        baseline_logits = (
            full_logits.clone().detach().requires_grad_(not self.inference_only)
        )
        baseline_log_probs = torch.nn.functional.log_softmax(baseline_logits, dim=-1)
        baseline_probs = baseline_log_probs.exp()
        # H = sum_v p_v * log(p_v) (negative entropy; matches ChunkedDistributedEntropy)
        baseline_entropy = (baseline_probs * baseline_log_probs).sum(dim=-1)  # [B, S]

        if not self.inference_only:
            baseline_entropy.sum().backward()
            # Keep only this rank's vocab-shard slice of the baseline gradient,
            # since the distributed op only produces gradients for the local shard.
            baseline_grad = baseline_logits.grad[
                :, :, vocab_start_index:vocab_end_index
            ].clone()

        # === Distributed: ChunkedDistributedEntropy ===
        local_logits = full_logits[:, :, vocab_start_index:vocab_end_index]
        # Re-detach so the distributed path has its own autograd graph.
        local_logits = (
            local_logits.clone().detach().requires_grad_(not self.inference_only)
        )

        distributed_entropy = ChunkedDistributedEntropy.apply(
            local_logits,
            self.chunk_size,
            tp_group,
            self.inference_only,
        )

        # Forward check
        torch.testing.assert_close(
            distributed_entropy, baseline_entropy, rtol=1e-4, atol=1e-4
        )
        forward_diff = torch.max(
            torch.abs(distributed_entropy - baseline_entropy)
        ).item()

        # Backward check
        if not self.inference_only:
            distributed_entropy.sum().backward()
            grad_local = local_logits.grad
            torch.testing.assert_close(grad_local, baseline_grad, rtol=1e-4, atol=1e-4)
            grad_diff = torch.max(torch.abs(grad_local - baseline_grad)).item()
        else:
            grad_diff = None

        return {
            "forward_max_diff": forward_diff,
            "grad_max_diff": grad_diff,
        }
# Fully-qualified actor name used as the ACTOR_ENVIRONMENT_REGISTRY key.
# The class-name suffix is spelled literally (not via __qualname__) because
# @ray.remote wraps the class in an ActorClass proxy.
CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN = (
    ChunkedDistributedEntropyTestActor.__module__
    + ".ChunkedDistributedEntropyTestActor"
)
@pytest.fixture
def register_chunked_distributed_entropy_test_actor():
    """Register the entropy test actor with the SYSTEM python executable.

    Yields:
        The actor's fully-qualified name (the registry key).

    On teardown, restores the registry to its pre-test state. Restoration is
    unconditional: the previous implementation guarded on the key still being
    present, so if the test body deleted the entry, a pre-existing
    original_registry_value was silently dropped and state leaked across tests.
    """
    original_registry_value = ACTOR_ENVIRONMENT_REGISTRY.get(
        CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN
    )
    ACTOR_ENVIRONMENT_REGISTRY[CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN] = (
        PY_EXECUTABLES.SYSTEM
    )

    yield CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN

    # Always restore: remove the key if there was no prior entry, otherwise
    # put the original value back (even if the test deleted the key meanwhile).
    if original_registry_value is None:
        ACTOR_ENVIRONMENT_REGISTRY.pop(
            CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN, None
        )
    else:
        ACTOR_ENVIRONMENT_REGISTRY[CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN] = (
            original_registry_value
        )
@pytest.mark.parametrize(
    "tp_size, chunk_size, inference_only",
    [
        (1, 5, False),
        (2, 4, False),
        (1, 3, True),
    ],
)
def test_chunked_distributed_entropy(
    register_chunked_distributed_entropy_test_actor,
    tp_size,
    chunk_size,
    inference_only,
):
    """Launch a TP-sharded Ray worker group and validate ChunkedDistributedEntropy.

    Spins up one GPU worker per TP rank, runs the actor-side comparison against
    the single-GPU baseline, and asserts the reported forward/gradient max
    diffs stay under 1e-4.
    """
    world_size = tp_size

    # Need one physical GPU per worker; otherwise skip rather than fail.
    if not torch.cuda.is_available() or torch.cuda.device_count() < world_size:
        pytest.skip(
            f"Not enough GPUs available. Need {world_size}, "
            f"got {torch.cuda.device_count()}"
        )

    cluster = RayVirtualCluster(bundle_ct_per_node_list=[world_size], use_gpus=True)

    try:
        actor_fqn = register_chunked_distributed_entropy_test_actor

        # Single "tp" axis covering all ranks.
        tp_layout = np.arange(world_size).reshape(tp_size)
        worker_group = RayWorkerGroup(
            cluster=cluster,
            remote_worker_builder=RayWorkerBuilder(
                actor_fqn, tp_size, chunk_size, inference_only
            ),
            workers_per_node=None,
            sharding_annotations=NamedSharding(layout=tp_layout, names=["tp"]),
        )

        results = ray.get(
            worker_group.run_all_workers_single_data(
                "test_chunked_distributed_entropy"
            )
        )

        for worker_idx, worker_result in enumerate(results):
            fwd_diff = worker_result["forward_max_diff"]
            print(f"Worker {worker_idx} forward max diff: {fwd_diff:.2e}")
            assert fwd_diff < 1e-4, (
                f"Worker {worker_idx} forward diff too large: {fwd_diff}"
            )

            grad_diff = worker_result["grad_max_diff"]
            if inference_only:
                assert grad_diff is None
            else:
                print(f"Worker {worker_idx} grad max diff: {grad_diff:.2e}")
                assert grad_diff is not None and grad_diff < 1e-4, (
                    f"Worker {worker_idx} grad diff too large: {grad_diff}"
                )

        worker_group.shutdown(force=True)

    finally:
        cluster.shutdown()
0 commit comments