18
18
EmbeddingTowerSharder ,
19
19
)
20
20
from torchrec .distributed .embedding_types import EmbeddingComputeKernel
21
- from torchrec .distributed .embeddingbag import (
22
- EmbeddingBagCollection ,
23
- EmbeddingBagCollectionSharder ,
24
- )
21
+ from torchrec .distributed .embeddingbag import EmbeddingBagCollectionSharder
25
22
from torchrec .distributed .mc_embeddingbag import (
26
23
ManagedCollisionEmbeddingBagCollectionSharder ,
27
24
)
48
45
[[17 , 80 ], [17 , 80 ], [17 , 80 ], [17 , 80 ], [17 , 80 ], [17 , 80 ], [17 , 80 ], [11 , 80 ]],
49
46
]
50
47
51
- EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS = [
52
- [[20 , 20 ], [20 , 20 ], [10 , 20 ], [10 , 20 ], [10 , 20 ], [10 , 20 ], [10 , 20 ], [10 , 20 ]],
53
- [[22 , 40 ], [22 , 40 ], [11 , 40 ], [11 , 40 ], [11 , 40 ], [11 , 40 ], [11 , 40 ], [11 , 40 ]],
54
- [[24 , 60 ], [24 , 60 ], [12 , 60 ], [12 , 60 ], [12 , 60 ], [12 , 60 ], [12 , 60 ], [12 , 60 ]],
55
- [[26 , 80 ], [26 , 80 ], [13 , 80 ], [13 , 80 ], [13 , 80 ], [13 , 80 ], [13 , 80 ], [13 , 80 ]],
56
- ]
57
-
58
48
EXPECTED_RW_SHARD_OFFSETS = [
59
49
[[0 , 0 ], [13 , 0 ], [26 , 0 ], [39 , 0 ], [52 , 0 ], [65 , 0 ], [78 , 0 ], [91 , 0 ]],
60
50
[[0 , 0 ], [14 , 0 ], [28 , 0 ], [42 , 0 ], [56 , 0 ], [70 , 0 ], [84 , 0 ], [98 , 0 ]],
61
51
[[0 , 0 ], [15 , 0 ], [30 , 0 ], [45 , 0 ], [60 , 0 ], [75 , 0 ], [90 , 0 ], [105 , 0 ]],
62
52
[[0 , 0 ], [17 , 0 ], [34 , 0 ], [51 , 0 ], [68 , 0 ], [85 , 0 ], [102 , 0 ], [119 , 0 ]],
63
53
]
64
54
65
- EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS = [
66
- [[0 , 0 ], [20 , 0 ], [40 , 0 ], [50 , 0 ], [60 , 0 ], [70 , 0 ], [80 , 0 ], [90 , 0 ]],
67
- [[0 , 0 ], [22 , 0 ], [44 , 0 ], [55 , 0 ], [66 , 0 ], [77 , 0 ], [88 , 0 ], [99 , 0 ]],
68
- [[0 , 0 ], [24 , 0 ], [48 , 0 ], [60 , 0 ], [72 , 0 ], [84 , 0 ], [96 , 0 ], [108 , 0 ]],
69
- [[0 , 0 ], [26 , 0 ], [52 , 0 ], [65 , 0 ], [78 , 0 ], [91 , 0 ], [104 , 0 ], [117 , 0 ]],
70
- ]
71
-
72
55
73
56
def get_expected_cache_aux_size (rows : int ) -> int :
74
57
# 0.2 is the hardcoded cache load factor assumed in this test
@@ -118,48 +101,6 @@ def get_expected_cache_aux_size(rows: int) -> int:
118
101
],
119
102
]
120
103
121
- EXPECTED_VIRTUAL_TABLE_RW_SHARD_STORAGE_WITH_BUCKETS = [
122
- [
123
- Storage (hbm = 165888 , ddr = 0 ),
124
- Storage (hbm = 165888 , ddr = 0 ),
125
- Storage (hbm = 165888 , ddr = 0 ),
126
- Storage (hbm = 165888 , ddr = 0 ),
127
- Storage (hbm = 165888 , ddr = 0 ),
128
- Storage (hbm = 165888 , ddr = 0 ),
129
- Storage (hbm = 165888 , ddr = 0 ),
130
- Storage (hbm = 165888 , ddr = 0 ),
131
- ],
132
- [
133
- Storage (hbm = 1001472 , ddr = 0 ),
134
- Storage (hbm = 1001472 , ddr = 0 ),
135
- Storage (hbm = 1001472 , ddr = 0 ),
136
- Storage (hbm = 1001472 , ddr = 0 ),
137
- Storage (hbm = 1001472 , ddr = 0 ),
138
- Storage (hbm = 1001472 , ddr = 0 ),
139
- Storage (hbm = 1001472 , ddr = 0 ),
140
- Storage (hbm = 1001472 , ddr = 0 ),
141
- ],
142
- [
143
- Storage (hbm = 1003520 , ddr = 0 ),
144
- Storage (hbm = 1003520 , ddr = 0 ),
145
- Storage (hbm = 1003520 , ddr = 0 ),
146
- Storage (hbm = 1003520 , ddr = 0 ),
147
- Storage (hbm = 1003520 , ddr = 0 ),
148
- Storage (hbm = 1003520 , ddr = 0 ),
149
- Storage (hbm = 1003520 , ddr = 0 ),
150
- Storage (hbm = 1003520 , ddr = 0 ),
151
- ],
152
- [
153
- Storage (hbm = 2648064 , ddr = 0 ),
154
- Storage (hbm = 2648064 , ddr = 0 ),
155
- Storage (hbm = 2648064 , ddr = 0 ),
156
- Storage (hbm = 2648064 , ddr = 0 ),
157
- Storage (hbm = 2648064 , ddr = 0 ),
158
- Storage (hbm = 2648064 , ddr = 0 ),
159
- Storage (hbm = 2648064 , ddr = 0 ),
160
- Storage (hbm = 2648064 , ddr = 0 ),
161
- ],
162
- ]
163
104
164
105
EXPECTED_UVM_CACHING_RW_SHARD_STORAGE = [
165
106
[
@@ -204,48 +145,6 @@ def get_expected_cache_aux_size(rows: int) -> int:
204
145
],
205
146
]
206
147
207
- EXPECTED_UVM_CACHING_RW_SHARD_STORAGE_WITH_BUCKETS = [
208
- [
209
- Storage (hbm = 166352 , ddr = 1600 ),
210
- Storage (hbm = 166352 , ddr = 1600 ),
211
- Storage (hbm = 166120 , ddr = 800 ),
212
- Storage (hbm = 166120 , ddr = 800 ),
213
- Storage (hbm = 166120 , ddr = 800 ),
214
- Storage (hbm = 166120 , ddr = 800 ),
215
- Storage (hbm = 166120 , ddr = 800 ),
216
- Storage (hbm = 166120 , ddr = 800 ),
217
- ],
218
- [
219
- Storage (hbm = 1002335 , ddr = 3520 ),
220
- Storage (hbm = 1002335 , ddr = 3520 ),
221
- Storage (hbm = 1001904 , ddr = 1760 ),
222
- Storage (hbm = 1001904 , ddr = 1760 ),
223
- Storage (hbm = 1001904 , ddr = 1760 ),
224
- Storage (hbm = 1001904 , ddr = 1760 ),
225
- Storage (hbm = 1001904 , ddr = 1760 ),
226
- Storage (hbm = 1001904 , ddr = 1760 ),
227
- ],
228
- [
229
- Storage (hbm = 1004845 , ddr = 5760 ),
230
- Storage (hbm = 1004845 , ddr = 5760 ),
231
- Storage (hbm = 1004183 , ddr = 2880 ),
232
- Storage (hbm = 1004183 , ddr = 2880 ),
233
- Storage (hbm = 1004183 , ddr = 2880 ),
234
- Storage (hbm = 1004183 , ddr = 2880 ),
235
- Storage (hbm = 1004183 , ddr = 2880 ),
236
- Storage (hbm = 1004183 , ddr = 2880 ),
237
- ],
238
- [
239
- Storage (hbm = 2649916 , ddr = 8320 ),
240
- Storage (hbm = 2649916 , ddr = 8320 ),
241
- Storage (hbm = 2648990 , ddr = 4160 ),
242
- Storage (hbm = 2648990 , ddr = 4160 ),
243
- Storage (hbm = 2648990 , ddr = 4160 ),
244
- Storage (hbm = 2648990 , ddr = 4160 ),
245
- Storage (hbm = 2648990 , ddr = 4160 ),
246
- Storage (hbm = 2648990 , ddr = 4160 ),
247
- ],
248
- ]
249
148
250
149
EXPECTED_TWRW_SHARD_SIZES = [
251
150
[[25 , 20 ], [25 , 20 ], [25 , 20 ], [25 , 20 ]],
@@ -349,16 +248,6 @@ def compute_kernels(
349
248
return [EmbeddingComputeKernel .FUSED .value ]
350
249
351
250
352
- class VirtualTableRWSharder (EmbeddingBagCollectionSharder ):
353
- def sharding_types (self , compute_device_type : str ) -> List [str ]:
354
- return [ShardingType .ROW_WISE .value ]
355
-
356
- def compute_kernels (
357
- self , sharding_type : str , compute_device_type : str
358
- ) -> List [str ]:
359
- return [EmbeddingComputeKernel .DRAM_VIRTUAL_TABLE .value ]
360
-
361
-
362
251
class UVMCachingRWSharder (EmbeddingBagCollectionSharder ):
363
252
def sharding_types (self , compute_device_type : str ) -> List [str ]:
364
253
return [ShardingType .ROW_WISE .value ]
@@ -468,27 +357,6 @@ def setUp(self) -> None:
468
357
min_partition = 40 , pooling_factors = [2 , 1 , 3 , 7 ]
469
358
),
470
359
}
471
- self ._virtual_table_constraints = {
472
- "table_0" : ParameterConstraints (
473
- min_partition = 20 ,
474
- compute_kernels = [EmbeddingComputeKernel .DRAM_VIRTUAL_TABLE .value ],
475
- ),
476
- "table_1" : ParameterConstraints (
477
- min_partition = 20 ,
478
- pooling_factors = [1 , 3 , 5 ],
479
- compute_kernels = [EmbeddingComputeKernel .DRAM_VIRTUAL_TABLE .value ],
480
- ),
481
- "table_2" : ParameterConstraints (
482
- min_partition = 20 ,
483
- pooling_factors = [8 , 2 ],
484
- compute_kernels = [EmbeddingComputeKernel .DRAM_VIRTUAL_TABLE .value ],
485
- ),
486
- "table_3" : ParameterConstraints (
487
- min_partition = 40 ,
488
- pooling_factors = [2 , 1 , 3 , 7 ],
489
- compute_kernels = [EmbeddingComputeKernel .DRAM_VIRTUAL_TABLE .value ],
490
- ),
491
- }
492
360
self .num_tables = 4
493
361
tables = [
494
362
EmbeddingBagConfig (
@@ -499,17 +367,6 @@ def setUp(self) -> None:
499
367
)
500
368
for i in range (self .num_tables )
501
369
]
502
- tables_with_buckets = [
503
- EmbeddingBagConfig (
504
- num_embeddings = 100 + i * 10 ,
505
- embedding_dim = 20 + i * 20 ,
506
- name = "table_" + str (i ),
507
- feature_names = ["feature_" + str (i )],
508
- total_num_buckets = 10 ,
509
- use_virtual_table = True ,
510
- )
511
- for i in range (self .num_tables )
512
- ]
513
370
weighted_tables = [
514
371
EmbeddingBagConfig (
515
372
num_embeddings = (i + 1 ) * 10 ,
@@ -520,9 +377,6 @@ def setUp(self) -> None:
520
377
for i in range (4 )
521
378
]
522
379
self .model = TestSparseNN (tables = tables , weighted_tables = [])
523
- self .model_with_buckets = EmbeddingBagCollection (
524
- tables = tables_with_buckets ,
525
- )
526
380
self .enumerator = EmbeddingEnumerator (
527
381
topology = Topology (
528
382
world_size = self .world_size ,
@@ -532,15 +386,6 @@ def setUp(self) -> None:
532
386
batch_size = self .batch_size ,
533
387
constraints = self .constraints ,
534
388
)
535
- self .virtual_table_enumerator = EmbeddingEnumerator (
536
- topology = Topology (
537
- world_size = self .world_size ,
538
- compute_device = self .compute_device ,
539
- local_world_size = self .local_world_size ,
540
- ),
541
- batch_size = self .batch_size ,
542
- constraints = self ._virtual_table_constraints ,
543
- )
544
389
self .tower_model = TestTowerSparseNN (
545
390
tables = tables , weighted_tables = weighted_tables
546
391
)
@@ -669,26 +514,6 @@ def test_rw_sharding(self) -> None:
669
514
EXPECTED_RW_SHARD_STORAGE [i ],
670
515
)
671
516
672
- def test_virtual_table_rw_sharding_with_buckets (self ) -> None :
673
- sharding_options = self .virtual_table_enumerator .enumerate (
674
- self .model_with_buckets ,
675
- [cast (ModuleSharder [torch .nn .Module ], VirtualTableRWSharder ())],
676
- )
677
- for i , sharding_option in enumerate (sharding_options ):
678
- self .assertEqual (sharding_option .sharding_type , ShardingType .ROW_WISE .value )
679
- self .assertEqual (
680
- [shard .size for shard in sharding_option .shards ],
681
- EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS [i ],
682
- )
683
- self .assertEqual (
684
- [shard .offset for shard in sharding_option .shards ],
685
- EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS [i ],
686
- )
687
- self .assertEqual (
688
- [shard .storage for shard in sharding_option .shards ],
689
- EXPECTED_VIRTUAL_TABLE_RW_SHARD_STORAGE_WITH_BUCKETS [i ],
690
- )
691
-
692
517
def test_uvm_caching_rw_sharding (self ) -> None :
693
518
sharding_options = self .enumerator .enumerate (
694
519
self .model ,
@@ -710,26 +535,6 @@ def test_uvm_caching_rw_sharding(self) -> None:
710
535
EXPECTED_UVM_CACHING_RW_SHARD_STORAGE [i ],
711
536
)
712
537
713
- def test_uvm_caching_rw_sharding_with_buckets (self ) -> None :
714
- sharding_options = self .enumerator .enumerate (
715
- self .model_with_buckets ,
716
- [cast (ModuleSharder [torch .nn .Module ], UVMCachingRWSharder ())],
717
- )
718
- for i , sharding_option in enumerate (sharding_options ):
719
- self .assertEqual (sharding_option .sharding_type , ShardingType .ROW_WISE .value )
720
- self .assertEqual (
721
- [shard .size for shard in sharding_option .shards ],
722
- EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS [i ],
723
- )
724
- self .assertEqual (
725
- [shard .offset for shard in sharding_option .shards ],
726
- EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS [i ],
727
- )
728
- self .assertEqual (
729
- [shard .storage for shard in sharding_option .shards ],
730
- EXPECTED_UVM_CACHING_RW_SHARD_STORAGE_WITH_BUCKETS [i ],
731
- )
732
-
733
538
def test_twrw_sharding (self ) -> None :
734
539
sharding_options = self .enumerator .enumerate (
735
540
self .model , [cast (ModuleSharder [torch .nn .Module ], TWRWSharder ())]
0 commit comments