@@ -29,14 +29,14 @@ def test_sizes_and_ranks(multidevice_test):
2929@pytest .mark .mpi
3030def test_pointwise (multidevice_test ):
3131 num_devices = multidevice_test .size
32- mesh = nvfuser .multidevice .DeviceMesh (torch .arange (num_devices ))
3332
3433 with FusionDefinition () as fd :
3534 inp_tv = fd .define_tensor ((- 1 , - 1 ), contiguity = False , dtype = DataType .Float )
3635 tv1 = fd .ops .relu (inp_tv )
3736 tv2 = fd .ops .add (tv1 , tv1 )
3837 fd .add_output (tv2 )
3938
39+ mesh = nvfuser .multidevice .DeviceMesh (torch .arange (num_devices ))
4040 for tv in [inp_tv , tv1 , tv2 ]:
4141 tv .set_device_mesh (mesh )
4242
@@ -50,6 +50,63 @@ def test_pointwise(multidevice_test):
5050 torch .testing .assert_close (out .cpu (), out_ref )
5151
5252
@pytest.mark.mpi
def test_transpose(multidevice_test):
    """Execute a single-op fusion whose tensors are sharded over a 3-D mesh.

    The fusion defines a 5-D bfloat16 input, applies ``fd.ops.set`` to it and
    registers the result as the output. Both tensors are placed on a device
    mesh of shape ``(dp_size, cp_size, cp_size)`` and scheduled identically;
    the output additionally gets an explicit allocation domain that moves
    axis 3 to the front.

    The test is skipped unless the number of devices is a multiple of
    ``cp_size * cp_size``.

    NOTE(review): the value returned by ``fd.execute`` is discarded and
    ``out_ref`` is never compared against anything, so as written this test
    only checks that execution completes without raising.
    """
    d = multidevice_test.size
    cp_size = 2
    devices_per_dp_group = cp_size * cp_size
    if d % devices_per_dp_group != 0:
        pytest.skip(
            f"We only support even split, so { d } has to be divisible by { cp_size * cp_size } for { cp_size = } ."
        )
    dp_size = d // devices_per_dp_group

    dim1_size = 128
    with FusionDefinition() as fd:
        inp_tv = fd.define_tensor(
            (-1, dim1_size, -1, -1, cp_size),
            contiguity=True,
            dtype=DataType.BFloat16,
        )
        out_tv = fd.ops.set(inp_tv)
        fd.add_output(out_tv)

        # Arrange all device ids into a (dp_size, cp_size, cp_size) grid.
        device_grid = torch.arange(d).reshape(dp_size, cp_size, cp_size)
        mesh = nvfuser.multidevice.DeviceMesh(device_grid)
        inp_tv.set_device_mesh(mesh)
        out_tv.set_device_mesh(mesh)

        # Input and output receive exactly the same sequence of scheduling
        # calls. Each outer_split presumably replaces one axis with two, so
        # later axis indices shift by one -- TODO confirm against nvFuser docs.
        for tv in (inp_tv, out_tv):
            tv.axis(4).parallelize(nvfuser.ParallelType.mesh_y)
            tv.outer_split(3, cp_size)
            tv.axis(3).parallelize(nvfuser.ParallelType.mesh_x)
            tv.outer_split(0, dp_size)
            tv.axis(0).parallelize(nvfuser.ParallelType.mesh_z)

        # Allocation order for the output's seven axes: axis 3 first, the
        # remaining axes in their existing relative order.
        allocation_order = (3, 0, 1, 2, 4, 5, 6)
        out_tv.set_allocation_domain(
            tuple(out_tv.axis(index) for index in allocation_order),
            True,  # presumably the contiguity flag -- TODO confirm
        )

    # Unsharded sizes are multiples of the mesh extents they are split by.
    dim0_size = dp_size * 3
    dim23_size = cp_size * 5
    inp_ref = torch.randn(
        dim0_size,
        dim1_size,
        dim23_size,
        dim23_size,
        cp_size,
        dtype=torch.bfloat16,
    )
    # NOTE(review): kept from the original; currently unused.
    out_ref = inp_ref

    inp = multidevice_test.shard_tensor(inp_ref, inp_tv)
    fd.execute([inp])
108+
109+
53110class QkvFormat (Enum ):
54111 BHSE = auto ()
55112 BSHE = auto ()
0 commit comments