Skip to content

Commit e5c2de4

Browse files
committed
save compare script
1 parent 7e0223f commit e5c2de4

File tree

1 file changed

+111
-0
lines changed

1 file changed

+111
-0
lines changed

compare_tensors.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
#!/usr/bin/env python3
"""
Compare tensor files between mixedTP2CP2_redo and pureTP4_redo directories.

Mapping (mixedTP2CP2 -> pureTP4):
    rank 0 -> rank 0
    rank 1 -> rank 2
    rank 2 -> rank 1
    rank 3 -> rank 3
"""
11+
12+
import torch
13+
import os
14+
from pathlib import Path
15+
16+
# Directories holding the dumped tensors for the two runs being compared.
mixed_dir = Path("/home/bbuddharaju/scratch/TensorRT-LLM/mixedTP2CP2_redo")
pure_dir = Path("/home/bbuddharaju/scratch/TensorRT-LLM/pureTP4_redo")

# Mapping: mixed_rank -> pure_rank.
# With the file names below, each entry satisfies
#     pure_tp == 2 * mixed_tp + mixed_cp
# NOTE(review): presumably this reflects how the mixed TP2/CP2 layout
# distributes work relative to pure TP4 -- confirm against the model code.
rank_mapping = {
    0: 0,
    1: 2,
    2: 1,
    3: 3,
}

# File patterns for before_o_proj.pt
# mixedTP2CP2: rank0_cp0_tp0, rank1_cp0_tp1, rank2_cp1_tp0, rank3_cp1_tp1
# pureTP4:     rank0_cp0_tp0, rank1_cp0_tp1, rank2_cp0_tp2, rank3_cp0_tp3

# Keyed by rank within the mixed (TP2 x CP2) run.
mixed_files = {
    0: "rank0_cp0_tp0_before_o_proj.pt",
    1: "rank1_cp0_tp1_before_o_proj.pt",
    2: "rank2_cp1_tp0_before_o_proj.pt",
    3: "rank3_cp1_tp1_before_o_proj.pt",
}

# Keyed by rank within the pure TP4 run.
pure_files = {
    0: "rank0_cp0_tp0_before_o_proj.pt",
    1: "rank1_cp0_tp1_before_o_proj.pt",
    2: "rank2_cp0_tp2_before_o_proj.pt",
    3: "rank3_cp0_tp3_before_o_proj.pt",
}
45+
46+
def compute_diff_stats(mixed_tensor, pure_tensor, eps=1e-8):
    """Return absolute and relative difference statistics for two tensors.

    Both tensors are converted with ``.float()`` before comparison and must
    have the same shape (the caller checks this).

    Args:
        mixed_tensor: Tensor from the mixed run.
        pure_tensor: Reference tensor from the pure run.
        eps: Reference elements whose magnitude is not greater than this
            are excluded from the relative statistics to avoid dividing by
            (near-)zero.

    Returns:
        A tuple ``(mean_abs, max_abs, mean_rel, max_rel)`` of Python floats.
        The relative entries are NaN when no reference element exceeds
        ``eps``; all four entries are NaN for zero-element tensors.
    """
    mixed_float = mixed_tensor.float()
    pure_float = pure_tensor.float()
    diff = torch.abs(mixed_float - pure_float)

    nan = float('nan')
    # Tensor.max() raises on a zero-element tensor, so report NaN instead.
    if diff.numel() == 0:
        return nan, nan, nan, nan

    mean_diff = diff.mean().item()
    max_diff = diff.max().item()

    # Relative differences only where the reference is meaningfully non-zero.
    pure_abs = pure_float.abs()
    non_zero_mask = pure_abs > eps
    if non_zero_mask.any():
        rel_diff = diff[non_zero_mask] / pure_abs[non_zero_mask]
        mean_rel_diff = rel_diff.mean().item()
        max_rel_diff = rel_diff.max().item()
    else:
        mean_rel_diff = nan
        max_rel_diff = nan

    return mean_diff, max_diff, mean_rel_diff, max_rel_diff


def _unwrap_tensor(loaded, label):
    """Return the first value if *loaded* is a dict, else *loaded* itself.

    Prints the dict keys when unwrapping. Returns None for an empty dict so
    the caller can report it instead of failing with an IndexError.
    """
    if not isinstance(loaded, dict):
        return loaded
    print(f" {label} tensor is a dict with keys: {list(loaded.keys())}")
    return next(iter(loaded.values()), None)


def _compare_rank_pair(mixed_rank, pure_rank):
    """Load the two files for one rank pair and print their differences."""
    mixed_file = mixed_dir / mixed_files[mixed_rank]
    pure_file = pure_dir / pure_files[pure_rank]

    print(f"\n[mixed rank {mixed_rank}] -> [pure rank {pure_rank}]")
    print(f" Mixed file: {mixed_files[mixed_rank]}")
    print(f" Pure file: {pure_files[pure_rank]}")

    if not mixed_file.exists():
        print(f" ERROR: Mixed file not found: {mixed_file}")
        return
    if not pure_file.exists():
        print(f" ERROR: Pure file not found: {pure_file}")
        return

    # Load tensors
    mixed_loaded = torch.load(mixed_file, map_location='cpu', weights_only=True)
    pure_loaded = torch.load(pure_file, map_location='cpu', weights_only=True)

    # Handle case where loaded data might be a dict
    mixed_tensor = _unwrap_tensor(mixed_loaded, "Mixed")
    pure_tensor = _unwrap_tensor(pure_loaded, "Pure")
    if mixed_tensor is None:
        print(f" ERROR: Mixed file contains an empty dict: {mixed_file}")
        return
    if pure_tensor is None:
        print(f" ERROR: Pure file contains an empty dict: {pure_file}")
        return

    print(f" Mixed shape: {mixed_tensor.shape}, dtype: {mixed_tensor.dtype}")
    print(f" Pure shape: {pure_tensor.shape}, dtype: {pure_tensor.dtype}")

    if mixed_tensor.shape != pure_tensor.shape:
        print(" WARNING: Shape mismatch!")
        return

    mean_diff, max_diff, mean_rel_diff, max_rel_diff = compute_diff_stats(
        mixed_tensor, pure_tensor
    )

    print(f" Mean absolute diff: {mean_diff:.6e}")
    print(f" Max absolute diff: {max_diff:.6e}")
    print(f" Mean relative diff: {mean_rel_diff:.6e}")
    print(f" Max relative diff: {max_rel_diff:.6e}")


def main():
    """Compare every mixed-run tensor file with its mapped pure-run file."""
    print("=" * 70)
    print("Comparing tensors: mixedTP2CP2_redo -> pureTP4_redo")
    print("=" * 70)

    for mixed_rank, pure_rank in rank_mapping.items():
        _compare_rank_pair(mixed_rank, pure_rank)

    print("\n" + "=" * 70)
    print("Comparison complete.")
    print("=" * 70)


if __name__ == "__main__":
    main()
111+

0 commit comments

Comments
 (0)