
Commit 5fdcd08

nipung90 authored and meta-codesync[bot] committed
Back out "Allow the ability for uneven row wise sharding based on number of buckets for zch" (#3446)
Summary: Pull Request resolved: #3446
Original commit changeset: 5d7e0d55bb83
Original Phabricator Diff: D79659949
Reviewed By: emlin
Differential Revision: D84267146
fbshipit-source-id: 9e9f00073720d4621aab55ebb2b476dc01b14127
1 parent f44611b commit 5fdcd08

File tree: 8 files changed, +28 -531 lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 5 additions & 12 deletions
@@ -456,23 +456,16 @@ def _get_sharded_local_buckets_for_zero_collision(
 
     for table in embedding_tables:
         total_num_buckets = none_throws(table.total_num_buckets)
+        assert (
+            total_num_buckets % world_size == 0
+        ), f"total_num_buckets={total_num_buckets} must be divisible by world_size={world_size}"
         assert (
             table.total_num_buckets
             and table.num_embeddings % table.total_num_buckets == 0
         ), f"Table size '{table.num_embeddings}' must be divisible by num_buckets '{table.total_num_buckets}'"
-        extra_local_buckets = int(local_rank < (total_num_buckets % world_size))
-        extra_bucket_padding = (
-            (total_num_buckets % world_size)
-            if local_rank >= (total_num_buckets % world_size)
-            else 0
-        )
-        bucket_offset_start = (
-            total_num_buckets // world_size + extra_local_buckets
-        ) * local_rank + extra_bucket_padding
+        bucket_offset_start = total_num_buckets // world_size * local_rank
         bucket_offset_end = min(
-            total_num_buckets,
-            (total_num_buckets // world_size + extra_local_buckets) * (local_rank + 1)
-            + extra_bucket_padding,
+            total_num_buckets, total_num_buckets // world_size * (local_rank + 1)
         )
         bucket_size = (
             table.num_embeddings + total_num_buckets - 1
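
The restored logic splits buckets evenly across ranks, which is why the reinstated assertion requires total_num_buckets to be divisible by world_size; the backed-out code instead gave the first (total_num_buckets % world_size) ranks one extra bucket. A minimal standalone sketch of both computations, using plain integers rather than the torchrec table/rank objects (function names are illustrative only):

def even_bucket_range(total_num_buckets: int, world_size: int, local_rank: int) -> tuple:
    # Restored behavior: requires an even split of buckets across ranks.
    assert total_num_buckets % world_size == 0
    per_rank = total_num_buckets // world_size
    return per_rank * local_rank, min(total_num_buckets, per_rank * (local_rank + 1))

def uneven_bucket_range(total_num_buckets: int, world_size: int, local_rank: int) -> tuple:
    # Backed-out behavior: the first (total_num_buckets % world_size) ranks get one extra bucket.
    remainder = total_num_buckets % world_size
    extra = int(local_rank < remainder)
    padding = remainder if local_rank >= remainder else 0
    start = (total_num_buckets // world_size + extra) * local_rank + padding
    end = min(total_num_buckets, (total_num_buckets // world_size + extra) * (local_rank + 1) + padding)
    return start, end

For example, with 10 buckets over 4 ranks, uneven_bucket_range yields (0, 3), (3, 6), (6, 8), (8, 10) for ranks 0 through 3, while even_bucket_range raises an AssertionError because 10 % 4 != 0.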

torchrec/distributed/embedding_kernel.py

Lines changed: 1 addition & 8 deletions
@@ -99,16 +99,9 @@ def create_virtual_table_global_metadata(
     # Otherwise it will only set correct size on current rank and
     # virtual PMT will trigger recalc for the correct global size/offset.
     # NOTE this currently only works for row-wise sharding
-    my_rank_shard_size = metadata.shards_metadata[my_rank].shard_sizes[0]
     for rank, shard_metadata in enumerate(metadata.shards_metadata):
         if use_param_size_as_rows:  # respect the param size and treat it as rows
-            # The param size only has the information for my_rank. In order to
-            # correctly calculate the size for other ranks, we need to use the current
-            # rank's shard size compared to the shard size of my_rank.
-            curr_rank_rows = (
-                param.size()[0]  # pyre-ignore[16]
-                * metadata.shards_metadata[rank].shard_sizes[0]
-            ) // my_rank_shard_size
+            curr_rank_rows = param.size()[0]  # pyre-ignore[16]
         else:
             curr_rank_rows = (
                 weight_count_per_rank[rank] if weight_count_per_rank is not None else 1
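
In other words, the restored code assumes every rank holds the same number of rows as the local param, while the backed-out code scaled the local row count by each rank's shard size relative to my_rank. A rough sketch of the two calculations, with shard_rows standing in for the row dimension of each rank's shard metadata (hypothetical helper names, not the torchrec API):

def rows_uniform(local_param_rows: int, rank: int) -> int:
    # Restored behavior: every rank is assumed to hold the same row count as my_rank.
    return local_param_rows

def rows_proportional(local_param_rows: int, shard_rows: list, my_rank: int, rank: int) -> int:
    # Backed-out behavior: scale my rank's param rows by the metadata shard-size ratio.
    return (local_param_rows * shard_rows[rank]) // shard_rows[my_rank]

For example, if my_rank=0 holds a 30-row param and the metadata shard rows are [30, 30, 20, 20], then rows_proportional(30, [30, 30, 20, 20], 0, 2) == 20, whereas rows_uniform(30, 2) == 30.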

torchrec/distributed/planner/enumerators.py

Lines changed: 1 addition & 33 deletions
@@ -38,10 +38,6 @@
     ShardingType,
 )
 from torchrec.modules.embedding_configs import DataType
-from torchrec.modules.embedding_modules import (
-    EmbeddingBagCollection,
-    EmbeddingCollection,
-)
 from torchrec.modules.embedding_tower import EmbeddingTower, EmbeddingTowerCollection
 
 
@@ -182,7 +178,7 @@ def enumerate(
                 # skip for other device groups
                 if device_group and device_group != self._compute_device:
                     continue
-                num_buckets = self._get_num_buckets(name, child_module)
+
                 sharding_options_per_table: List[ShardingOption] = []
 
                 for sharding_type in self._filter_sharding_types(
@@ -204,7 +200,6 @@ def enumerate(
                             sharding_type=sharding_type,
                             col_wise_shard_dim=col_wise_shard_dim,
                             device_memory_sizes=self._device_memory_sizes,
-                            num_buckets=num_buckets,
                         )
                     except ZeroDivisionError as e:
                         # Re-raise with additional context about the table and module
@@ -269,33 +264,6 @@ def enumerate(
         self._last_stored_search_space = copy.deepcopy(sharding_options)
         return sharding_options
 
-    def _get_num_buckets(self, parameter: str, module: nn.Module) -> Optional[int]:
-        """
-        Get the number of buckets for each embedding table.
-
-        Args:
-            parameter (str): name of the embedding table.
-            module (nn.Module): module to be sharded.
-
-        Returns:
-            Optional[int]: Number of buckets for the table, or None if module is not EmbeddingBagCollection or table not found.
-        """
-        # If module is not of type EmbeddingBagCollection, return None
-        if isinstance(module, EmbeddingBagCollection):
-            embedding_configs = module.embedding_bag_configs()
-        elif isinstance(module, EmbeddingCollection):
-            embedding_configs = module.embedding_configs()
-        else:
-            return None
-
-        # Find the embedding config for the table with the same name as parameter input
-        for config in embedding_configs:
-            if config.name == parameter and config.use_virtual_table:
-                return config.total_num_buckets
-
-        # If table with matching name not found, return None
-        return None
-
     @property
     def last_stored_search_space(self) -> Optional[List[ShardingOption]]:
         # NOTE: This is the last search space stored by enumerate(...), do not use
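
With the backout, the enumerator no longer looks up a table's bucket count, so row-wise shard sizes are planned without bucket awareness. The removed helper amounted to a config lookup along these lines (a simplified free-function sketch with a stand-in config class; the real code inspects EmbeddingBagCollection / EmbeddingCollection configs):

from typing import List, Optional

class TableConfig:
    # Stand-in for the EmbeddingBagConfig fields used by the removed helper.
    def __init__(self, name: str, use_virtual_table: bool, total_num_buckets: Optional[int]) -> None:
        self.name = name
        self.use_virtual_table = use_virtual_table
        self.total_num_buckets = total_num_buckets

def get_num_buckets(table_name: str, configs: List[TableConfig]) -> Optional[int]:
    # Return the bucket count only for virtual tables; None otherwise.
    for config in configs:
        if config.name == table_name and config.use_virtual_table:
            return config.total_num_buckets
    return None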

torchrec/distributed/planner/tests/test_enumerators.py

Lines changed: 1 addition & 196 deletions
@@ -18,10 +18,7 @@
     EmbeddingTowerSharder,
 )
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
-from torchrec.distributed.embeddingbag import (
-    EmbeddingBagCollection,
-    EmbeddingBagCollectionSharder,
-)
+from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
 from torchrec.distributed.mc_embeddingbag import (
     ManagedCollisionEmbeddingBagCollectionSharder,
 )
@@ -48,27 +45,13 @@
     [[17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [11, 80]],
 ]
 
-EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS = [
-    [[20, 20], [20, 20], [10, 20], [10, 20], [10, 20], [10, 20], [10, 20], [10, 20]],
-    [[22, 40], [22, 40], [11, 40], [11, 40], [11, 40], [11, 40], [11, 40], [11, 40]],
-    [[24, 60], [24, 60], [12, 60], [12, 60], [12, 60], [12, 60], [12, 60], [12, 60]],
-    [[26, 80], [26, 80], [13, 80], [13, 80], [13, 80], [13, 80], [13, 80], [13, 80]],
-]
-
 EXPECTED_RW_SHARD_OFFSETS = [
     [[0, 0], [13, 0], [26, 0], [39, 0], [52, 0], [65, 0], [78, 0], [91, 0]],
     [[0, 0], [14, 0], [28, 0], [42, 0], [56, 0], [70, 0], [84, 0], [98, 0]],
     [[0, 0], [15, 0], [30, 0], [45, 0], [60, 0], [75, 0], [90, 0], [105, 0]],
     [[0, 0], [17, 0], [34, 0], [51, 0], [68, 0], [85, 0], [102, 0], [119, 0]],
 ]
 
-EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS = [
-    [[0, 0], [20, 0], [40, 0], [50, 0], [60, 0], [70, 0], [80, 0], [90, 0]],
-    [[0, 0], [22, 0], [44, 0], [55, 0], [66, 0], [77, 0], [88, 0], [99, 0]],
-    [[0, 0], [24, 0], [48, 0], [60, 0], [72, 0], [84, 0], [96, 0], [108, 0]],
-    [[0, 0], [26, 0], [52, 0], [65, 0], [78, 0], [91, 0], [104, 0], [117, 0]],
-]
-
 
 def get_expected_cache_aux_size(rows: int) -> int:
     # 0.2 is the hardcoded cache load factor assumed in this test
@@ -118,48 +101,6 @@ def get_expected_cache_aux_size(rows: int) -> int:
     ],
 ]
 
-EXPECTED_VIRTUAL_TABLE_RW_SHARD_STORAGE_WITH_BUCKETS = [
-    [
-        Storage(hbm=165888, ddr=0),
-        Storage(hbm=165888, ddr=0),
-        Storage(hbm=165888, ddr=0),
-        Storage(hbm=165888, ddr=0),
-        Storage(hbm=165888, ddr=0),
-        Storage(hbm=165888, ddr=0),
-        Storage(hbm=165888, ddr=0),
-        Storage(hbm=165888, ddr=0),
-    ],
-    [
-        Storage(hbm=1001472, ddr=0),
-        Storage(hbm=1001472, ddr=0),
-        Storage(hbm=1001472, ddr=0),
-        Storage(hbm=1001472, ddr=0),
-        Storage(hbm=1001472, ddr=0),
-        Storage(hbm=1001472, ddr=0),
-        Storage(hbm=1001472, ddr=0),
-        Storage(hbm=1001472, ddr=0),
-    ],
-    [
-        Storage(hbm=1003520, ddr=0),
-        Storage(hbm=1003520, ddr=0),
-        Storage(hbm=1003520, ddr=0),
-        Storage(hbm=1003520, ddr=0),
-        Storage(hbm=1003520, ddr=0),
-        Storage(hbm=1003520, ddr=0),
-        Storage(hbm=1003520, ddr=0),
-        Storage(hbm=1003520, ddr=0),
-    ],
-    [
-        Storage(hbm=2648064, ddr=0),
-        Storage(hbm=2648064, ddr=0),
-        Storage(hbm=2648064, ddr=0),
-        Storage(hbm=2648064, ddr=0),
-        Storage(hbm=2648064, ddr=0),
-        Storage(hbm=2648064, ddr=0),
-        Storage(hbm=2648064, ddr=0),
-        Storage(hbm=2648064, ddr=0),
-    ],
-]
 
 EXPECTED_UVM_CACHING_RW_SHARD_STORAGE = [
     [
@@ -204,48 +145,6 @@ def get_expected_cache_aux_size(rows: int) -> int:
     ],
 ]
 
-EXPECTED_UVM_CACHING_RW_SHARD_STORAGE_WITH_BUCKETS = [
-    [
-        Storage(hbm=166352, ddr=1600),
-        Storage(hbm=166352, ddr=1600),
-        Storage(hbm=166120, ddr=800),
-        Storage(hbm=166120, ddr=800),
-        Storage(hbm=166120, ddr=800),
-        Storage(hbm=166120, ddr=800),
-        Storage(hbm=166120, ddr=800),
-        Storage(hbm=166120, ddr=800),
-    ],
-    [
-        Storage(hbm=1002335, ddr=3520),
-        Storage(hbm=1002335, ddr=3520),
-        Storage(hbm=1001904, ddr=1760),
-        Storage(hbm=1001904, ddr=1760),
-        Storage(hbm=1001904, ddr=1760),
-        Storage(hbm=1001904, ddr=1760),
-        Storage(hbm=1001904, ddr=1760),
-        Storage(hbm=1001904, ddr=1760),
-    ],
-    [
-        Storage(hbm=1004845, ddr=5760),
-        Storage(hbm=1004845, ddr=5760),
-        Storage(hbm=1004183, ddr=2880),
-        Storage(hbm=1004183, ddr=2880),
-        Storage(hbm=1004183, ddr=2880),
-        Storage(hbm=1004183, ddr=2880),
-        Storage(hbm=1004183, ddr=2880),
-        Storage(hbm=1004183, ddr=2880),
-    ],
-    [
-        Storage(hbm=2649916, ddr=8320),
-        Storage(hbm=2649916, ddr=8320),
-        Storage(hbm=2648990, ddr=4160),
-        Storage(hbm=2648990, ddr=4160),
-        Storage(hbm=2648990, ddr=4160),
-        Storage(hbm=2648990, ddr=4160),
-        Storage(hbm=2648990, ddr=4160),
-        Storage(hbm=2648990, ddr=4160),
-    ],
-]
 
 EXPECTED_TWRW_SHARD_SIZES = [
     [[25, 20], [25, 20], [25, 20], [25, 20]],
@@ -349,16 +248,6 @@ def compute_kernels(
         return [EmbeddingComputeKernel.FUSED.value]
 
 
-class VirtualTableRWSharder(EmbeddingBagCollectionSharder):
-    def sharding_types(self, compute_device_type: str) -> List[str]:
-        return [ShardingType.ROW_WISE.value]
-
-    def compute_kernels(
-        self, sharding_type: str, compute_device_type: str
-    ) -> List[str]:
-        return [EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value]
-
-
 class UVMCachingRWSharder(EmbeddingBagCollectionSharder):
     def sharding_types(self, compute_device_type: str) -> List[str]:
         return [ShardingType.ROW_WISE.value]
@@ -468,27 +357,6 @@ def setUp(self) -> None:
                 min_partition=40, pooling_factors=[2, 1, 3, 7]
             ),
         }
-        self._virtual_table_constraints = {
-            "table_0": ParameterConstraints(
-                min_partition=20,
-                compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value],
-            ),
-            "table_1": ParameterConstraints(
-                min_partition=20,
-                pooling_factors=[1, 3, 5],
-                compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value],
-            ),
-            "table_2": ParameterConstraints(
-                min_partition=20,
-                pooling_factors=[8, 2],
-                compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value],
-            ),
-            "table_3": ParameterConstraints(
-                min_partition=40,
-                pooling_factors=[2, 1, 3, 7],
-                compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value],
-            ),
-        }
         self.num_tables = 4
         tables = [
             EmbeddingBagConfig(
@@ -499,17 +367,6 @@ def setUp(self) -> None:
             )
             for i in range(self.num_tables)
         ]
-        tables_with_buckets = [
-            EmbeddingBagConfig(
-                num_embeddings=100 + i * 10,
-                embedding_dim=20 + i * 20,
-                name="table_" + str(i),
-                feature_names=["feature_" + str(i)],
-                total_num_buckets=10,
-                use_virtual_table=True,
-            )
-            for i in range(self.num_tables)
-        ]
         weighted_tables = [
             EmbeddingBagConfig(
                 num_embeddings=(i + 1) * 10,
@@ -520,9 +377,6 @@ def setUp(self) -> None:
             for i in range(4)
         ]
         self.model = TestSparseNN(tables=tables, weighted_tables=[])
-        self.model_with_buckets = EmbeddingBagCollection(
-            tables=tables_with_buckets,
-        )
         self.enumerator = EmbeddingEnumerator(
             topology=Topology(
                 world_size=self.world_size,
@@ -532,15 +386,6 @@ def setUp(self) -> None:
             batch_size=self.batch_size,
             constraints=self.constraints,
         )
-        self.virtual_table_enumerator = EmbeddingEnumerator(
-            topology=Topology(
-                world_size=self.world_size,
-                compute_device=self.compute_device,
-                local_world_size=self.local_world_size,
-            ),
-            batch_size=self.batch_size,
-            constraints=self._virtual_table_constraints,
-        )
         self.tower_model = TestTowerSparseNN(
             tables=tables, weighted_tables=weighted_tables
         )
@@ -669,26 +514,6 @@ def test_rw_sharding(self) -> None:
                 EXPECTED_RW_SHARD_STORAGE[i],
             )
 
-    def test_virtual_table_rw_sharding_with_buckets(self) -> None:
-        sharding_options = self.virtual_table_enumerator.enumerate(
-            self.model_with_buckets,
-            [cast(ModuleSharder[torch.nn.Module], VirtualTableRWSharder())],
-        )
-        for i, sharding_option in enumerate(sharding_options):
-            self.assertEqual(sharding_option.sharding_type, ShardingType.ROW_WISE.value)
-            self.assertEqual(
-                [shard.size for shard in sharding_option.shards],
-                EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS[i],
-            )
-            self.assertEqual(
-                [shard.offset for shard in sharding_option.shards],
-                EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS[i],
-            )
-            self.assertEqual(
-                [shard.storage for shard in sharding_option.shards],
-                EXPECTED_VIRTUAL_TABLE_RW_SHARD_STORAGE_WITH_BUCKETS[i],
-            )
-
     def test_uvm_caching_rw_sharding(self) -> None:
         sharding_options = self.enumerator.enumerate(
             self.model,
@@ -710,26 +535,6 @@ def test_uvm_caching_rw_sharding(self) -> None:
                 EXPECTED_UVM_CACHING_RW_SHARD_STORAGE[i],
             )
 
-    def test_uvm_caching_rw_sharding_with_buckets(self) -> None:
-        sharding_options = self.enumerator.enumerate(
-            self.model_with_buckets,
-            [cast(ModuleSharder[torch.nn.Module], UVMCachingRWSharder())],
-        )
-        for i, sharding_option in enumerate(sharding_options):
-            self.assertEqual(sharding_option.sharding_type, ShardingType.ROW_WISE.value)
-            self.assertEqual(
-                [shard.size for shard in sharding_option.shards],
-                EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS[i],
-            )
-            self.assertEqual(
-                [shard.offset for shard in sharding_option.shards],
-                EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS[i],
-            )
-            self.assertEqual(
-                [shard.storage for shard in sharding_option.shards],
-                EXPECTED_UVM_CACHING_RW_SHARD_STORAGE_WITH_BUCKETS[i],
-            )
-
     def test_twrw_sharding(self) -> None:
         sharding_options = self.enumerator.enumerate(
             self.model, [cast(ModuleSharder[torch.nn.Module], TWRWSharder())]
