1
1
#!/usr/bin/env python3
2
- # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
3
7
4
8
# pyre-strict
5
9
@@ -28,7 +32,7 @@ class TestMCH(unittest.TestCase):
28
32
# pyre-ignore[56]
29
33
@unittest .skipIf (
30
34
torch .cuda .device_count () < 1 ,
31
- "Not enough GPUs, this test requires at least two GPUs " ,
35
+ "Not enough GPUs, this test requires at least one GPU " ,
32
36
)
33
37
def test_zch_hash_inference (self ) -> None :
34
38
# prepare
@@ -143,11 +147,6 @@ def test_zch_hash_inference(self) -> None:
143
147
f"{ torch .unique (m3 ._hash_zch_identities )= } " ,
144
148
)
145
149
146
- # pyre-ignore[56]
147
- @unittest .skipIf (
148
- torch .cuda .device_count () < 1 ,
149
- "This test requires CUDA device" ,
150
- )
151
150
def test_scriptability (self ) -> None :
152
151
zch_size = 10
153
152
mc_modules = {
@@ -180,11 +179,6 @@ def test_scriptability(self) -> None:
180
179
)
181
180
torch .jit .script (mcc_ec )
182
181
183
- # pyre-ignore[56]
184
- @unittest .skipIf (
185
- torch .cuda .device_count () < 1 ,
186
- "This test requires CUDA device" ,
187
- )
188
182
def test_scriptability_lru (self ) -> None :
189
183
zch_size = 10
190
184
mc_modules = {
@@ -219,13 +213,13 @@ def test_scriptability_lru(self) -> None:
219
213
torch .jit .script (mcc_ec )
220
214
221
215
@unittest .skipIf (
222
- torch .cuda .device_count () < 1 ,
223
- "Not enough GPUs, this test requires at least one GPUs" ,
216
+ torch .cuda .device_count () < 2 ,
217
+ "Not enough GPUs, this test requires at least two GPUs" ,
224
218
)
225
219
# pyre-ignore [56]
226
220
@given (hash_size = st .sampled_from ([0 , 80 ]), keep_original_indices = st .booleans ())
227
221
@settings (max_examples = 6 , deadline = None )
228
- def test_zch_hash_train_to_inf_block_bucketize (
222
+ def test_zch_hash_train_to_inf_block_bucketize_disabled_in_oss_compatibility (
229
223
self , hash_size : int , keep_original_indices : bool
230
224
) -> None :
231
225
# rank 0
@@ -298,13 +292,15 @@ def test_zch_hash_train_to_inf_block_bucketize(
298
292
)
299
293
300
294
@unittest .skipIf (
301
- torch .cuda .device_count () < 1 ,
302
- "Not enough GPUs, this test requires at least one GPUs" ,
295
+ torch .cuda .device_count () < 2 ,
296
+ "Not enough GPUs, this test requires at least two GPUs" ,
303
297
)
304
298
# pyre-ignore [56]
305
299
@given (hash_size = st .sampled_from ([0 , 80 ]))
306
300
@settings (max_examples = 5 , deadline = None )
307
- def test_zch_hash_train_rescales_two (self , hash_size : int ) -> None :
301
+ def test_zch_hash_train_rescales_two_disabled_in_oss_compatibility (
302
+ self , hash_size : int
303
+ ) -> None :
308
304
keep_original_indices = False
309
305
# rank 0
310
306
world_size = 2
@@ -410,13 +406,13 @@ def test_zch_hash_train_rescales_two(self, hash_size: int) -> None:
410
406
)
411
407
412
408
@unittest .skipIf (
413
- torch .cuda .device_count () < 1 ,
409
+ torch .cuda .device_count () < 2 ,
414
410
"Not enough GPUs, this test requires at least one GPUs" ,
415
411
)
416
412
# pyre-ignore [56]
417
413
@given (hash_size = st .sampled_from ([0 , 80 ]))
418
414
@settings (max_examples = 5 , deadline = None )
419
- def test_zch_hash_train_rescales_four (self , hash_size : int ) -> None :
415
+ def test_zch_hash_train_rescales_one (self , hash_size : int ) -> None :
420
416
keep_original_indices = True
421
417
kjt = KeyedJaggedTensor (
422
418
keys = ["f" ],
@@ -452,23 +448,20 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
452
448
),
453
449
)
454
450
455
- # start with world_size = 4
456
- world_size = 4
451
+ # start with world_size = 2
452
+ world_size = 2
457
453
block_sizes = torch .tensor (
458
454
[(size + world_size - 1 ) // world_size for size in [hash_size ]],
459
455
dtype = torch .int64 ,
460
456
device = "cuda" ,
461
457
)
462
458
463
- m1_1 = m0 .rebuild_with_output_id_range ((0 , 10 ))
464
- m2_1 = m0 .rebuild_with_output_id_range ((10 , 20 ))
465
- m3_1 = m0 .rebuild_with_output_id_range ((20 , 30 ))
466
- m4_1 = m0 .rebuild_with_output_id_range ((30 , 40 ))
459
+ m1_1 = m0 .rebuild_with_output_id_range ((0 , 20 ))
460
+ m2_1 = m0 .rebuild_with_output_id_range ((20 , 40 ))
467
461
468
- # shard, now world size 2!
469
- # start with world_size = 4
462
+ # shard, now world size 1!
470
463
if hash_size > 0 :
471
- world_size = 2
464
+ world_size = 1
472
465
block_sizes = torch .tensor (
473
466
[(size + world_size - 1 ) // world_size for size in [hash_size ]],
474
467
dtype = torch .int64 ,
@@ -482,7 +475,7 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
482
475
keep_original_indices = keep_original_indices ,
483
476
output_permute = True ,
484
477
)
485
- in1_2 , in2_2 = bucketized_kjt .split ([len (kjt .keys ())] * world_size )
478
+ in1_2 = bucketized_kjt .split ([len (kjt .keys ())] * world_size )[ 0 ]
486
479
else :
487
480
bucketized_kjt , permute = bucketize_kjt_before_all2all (
488
481
kjt ,
@@ -498,14 +491,8 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
498
491
values = torch .cat ([kjts [0 ].values (), kjts [1 ].values ()], dim = 0 ),
499
492
lengths = torch .cat ([kjts [0 ].lengths (), kjts [1 ].lengths ()], dim = 0 ),
500
493
)
501
- in2_2 = KeyedJaggedTensor (
502
- keys = kjts [2 ].keys (),
503
- values = torch .cat ([kjts [2 ].values (), kjts [3 ].values ()], dim = 0 ),
504
- lengths = torch .cat ([kjts [2 ].lengths (), kjts [3 ].lengths ()], dim = 0 ),
505
- )
506
494
507
- m1_2 = m0 .rebuild_with_output_id_range ((0 , 20 ))
508
- m2_2 = m0 .rebuild_with_output_id_range ((20 , 40 ))
495
+ m1_2 = m0 .rebuild_with_output_id_range ((0 , 40 ))
509
496
m1_zch_identities = torch .cat (
510
497
[
511
498
m1_1 .state_dict ()["_hash_zch_identities" ],
@@ -522,53 +509,30 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
522
509
state_dict ["_hash_zch_identities" ] = m1_zch_identities
523
510
state_dict ["_hash_zch_metadata" ] = m1_zch_metadata
524
511
m1_2 .load_state_dict (state_dict )
525
-
526
- m2_zch_identities = torch .cat (
527
- [
528
- m3_1 .state_dict ()["_hash_zch_identities" ],
529
- m4_1 .state_dict ()["_hash_zch_identities" ],
530
- ]
531
- )
532
- m2_zch_metadata = torch .cat (
533
- [
534
- m3_1 .state_dict ()["_hash_zch_metadata" ],
535
- m4_1 .state_dict ()["_hash_zch_metadata" ],
536
- ]
537
- )
538
- state_dict = m2_2 .state_dict ()
539
- state_dict ["_hash_zch_identities" ] = m2_zch_identities
540
- state_dict ["_hash_zch_metadata" ] = m2_zch_metadata
541
- m2_2 .load_state_dict (state_dict )
542
-
543
512
_ = m1_2 (in1_2 .to_dict ())
544
- _ = m2_2 (in2_2 .to_dict ())
545
513
546
514
m0 .reset_inference_mode () # just clears out training state
547
515
full_zch_identities = torch .cat (
548
516
[
549
517
m1_2 .state_dict ()["_hash_zch_identities" ],
550
- m2_2 .state_dict ()["_hash_zch_identities" ],
551
518
]
552
519
)
553
520
state_dict = m0 .state_dict ()
554
521
state_dict ["_hash_zch_identities" ] = full_zch_identities
555
522
m0 .load_state_dict (state_dict )
556
523
557
- # now set all models to eval, and run kjt
558
524
m1_2 .eval ()
559
- m2_2 .eval ()
560
525
assert m0 .training is False
561
526
562
527
inf_input = kjt .to_dict ()
563
- inf_output = m0 (inf_input )
564
528
529
+ inf_output = m0 (inf_input )
565
530
o1_2 = m1_2 (in1_2 .to_dict ())
566
- o2_2 = m2_2 (in2_2 .to_dict ())
567
531
self .assertTrue (
568
532
torch .allclose (
569
533
inf_output ["f" ].values (),
570
534
torch .index_select (
571
- torch . cat ([ x [ "f" ].values () for x in [ o1_2 , o2_2 ]] ),
535
+ o1_2 [ "f" ].values (),
572
536
dim = 0 ,
573
537
index = cast (torch .Tensor , permute ),
574
538
),
@@ -578,7 +542,7 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
578
542
# pyre-ignore[56]
579
543
@unittest .skipIf (
580
544
torch .cuda .device_count () < 1 ,
581
- "This test requires CUDA device " ,
545
+ "This test requires at least one GPU " ,
582
546
)
583
547
def test_output_global_offset_tensor (self ) -> None :
584
548
m = HashZchManagedCollisionModule (
@@ -653,7 +617,7 @@ def test_output_global_offset_tensor(self) -> None:
653
617
# pyre-ignore[56]
654
618
@unittest .skipIf (
655
619
torch .cuda .device_count () < 1 ,
656
- "This test requires CUDA device " ,
620
+ "This test requires at least one GPU " ,
657
621
)
658
622
def test_dynamically_switch_inference_training_mode (self ) -> None :
659
623
m = HashZchManagedCollisionModule (
0 commit comments