Skip to content

Commit 4673c16

Browse files
isururanawakameta-codesync[bot]
authored and committed
update num_poolings attribute to pass to the sharding options. (#3441)
Summary: Pull Request resolved: #3441. The io_sizes, output_sizes, and input_sizes calculations depend on num_poolings. Update num_poolings so it is fed from the Manifold planner configs into the sharding options, so that the recalculated sharding options are correct. Reviewed By: mserturk Differential Revision: D84111173 fbshipit-source-id: 0403396502dbc6015038de7ffe275af55f2130a8
1 parent 6de403e commit 4673c16

File tree

3 files changed

+30
-33
lines changed

3 files changed

+30
-33
lines changed

torchrec/distributed/planner/shard_estimators.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
HALF_BLOCK_PENALTY,
2525
kernel_bw_lookup,
2626
KV_CACHING_RATIO,
27+
NUM_POOLINGS,
2728
QUARTER_BLOCK_PENALTY,
2829
UVM_CACHING_RATIO,
2930
WEIGHTED_KERNEL_MULTIPLIER,
@@ -123,13 +124,7 @@ def estimate(
123124
else None
124125
)
125126

126-
num_poolings = (
127-
cast(List[float], self._constraints[sharding_option.name].num_poolings)
128-
if self._constraints
129-
and self._constraints.get(sharding_option.name)
130-
and self._constraints[sharding_option.name].num_poolings
131-
else [1.0] * sharding_option.num_inputs
132-
)
127+
num_poolings = get_num_poolings(self._constraints, sharding_option)
133128
batch_sizes = (
134129
cast(List[int], self._constraints[sharding_option.name].batch_sizes)
135130
if self._constraints
@@ -1008,11 +1003,7 @@ def estimate(
10081003
if self._constraints
10091004
else None
10101005
)
1011-
num_poolings = (
1012-
constraints.num_poolings
1013-
if constraints and constraints.num_poolings
1014-
else [1.0] * sharding_option.num_inputs
1015-
)
1006+
num_poolings = get_num_poolings(self._constraints, sharding_option)
10161007
assert len(num_poolings) == sharding_option.num_inputs
10171008
batch_sizes = (
10181009
constraints.batch_sizes
@@ -1313,6 +1304,25 @@ def _is_table_cached(
13131304
return False
13141305

13151306

1307+
def get_num_poolings(
1308+
constraints: Optional[Dict[str, ParameterConstraints]], so: ShardingOption
1309+
) -> List[float]:
1310+
# first priority is given for sharding_option.num_poolings,
1311+
# otherwise Manifold planner configs will be overwritten by parameter constraints
1312+
# default path will use constraints
1313+
if so.num_poolings is not None:
1314+
num_poolings = so.num_poolings
1315+
if len(so.input_lengths) == len(num_poolings):
1316+
return num_poolings
1317+
1318+
# Second priority: use constraint-based num_poolings
1319+
if constraints and constraints.get(so.name) and constraints[so.name].num_poolings:
1320+
return cast(List[float], constraints[so.name].num_poolings)
1321+
1322+
# Fallback: use default NUM_POOLINGS constant
1323+
return [NUM_POOLINGS] * len(so.input_lengths)
1324+
1325+
13161326
def _calculate_shard_io_sizes(
13171327
sharding_type: str,
13181328
batch_sizes: List[int],

torchrec/distributed/planner/stats.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@
2929

3030
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
3131
from torchrec.distributed.planner.constants import BIGINT_DTYPE, NUM_POOLINGS
32-
from torchrec.distributed.planner.shard_estimators import _calculate_shard_io_sizes
32+
from torchrec.distributed.planner.shard_estimators import (
33+
_calculate_shard_io_sizes,
34+
get_num_poolings,
35+
)
3336
from torchrec.distributed.planner.storage_reservations import (
3437
FixedPercentageStorageReservation,
3538
HeuristicalStorageReservation,
@@ -361,13 +364,7 @@ def _get_shard_stats(
361364
assert shard.ranks
362365
ranks = shard.ranks
363366

364-
num_poolings = (
365-
cast(List[float], constraints[sharding_option.name].num_poolings)
366-
if constraints
367-
and constraints.get(sharding_option.name)
368-
and constraints[sharding_option.name].num_poolings
369-
else [1.0] * sharding_option.num_inputs
370-
)
367+
num_poolings = get_num_poolings(constraints, sharding_option)
371368
batch_sizes = (
372369
cast(List[int], constraints[sharding_option.name].batch_sizes)
373370
if constraints
@@ -761,18 +758,6 @@ def _get_embedding_dim(so: ShardingOption) -> str:
761758
)
762759
return embedding_dim
763760

764-
def _get_num_poolings(
765-
constraints: Optional[Dict[str, ParameterConstraints]], so: ShardingOption
766-
) -> List[float]:
767-
num_poolings = (
768-
cast(List[float], constraints[so.name].num_poolings)
769-
if constraints
770-
and constraints.get(so.name)
771-
and constraints[so.name].num_poolings
772-
else [NUM_POOLINGS] * len(so.input_lengths)
773-
)
774-
return num_poolings
775-
776761
def _get_cache_load_factor(
777762
sharder: Optional[ModuleSharder[nn.Module]], so: ShardingOption
778763
) -> str:
@@ -865,7 +850,7 @@ def _get_cache_load_factor(
865850
shard_storages = _format_storage_breakdown(so_storage)
866851

867852
pooling_factor = str(round(sum(so.input_lengths), 3))
868-
num_poolings = _get_num_poolings(constraints, so)
853+
num_poolings = get_num_poolings(constraints, so)
869854
num_indices = str(
870855
round(sum(x * y for x, y in zip(so.input_lengths, num_poolings)), 3)
871856
)

torchrec/distributed/planner/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,7 @@ def __init__(
532532
feature_names: Optional[List[str]] = None,
533533
output_dtype: Optional[DataType] = None,
534534
key_value_params: Optional[KeyValueParams] = None,
535+
num_poolings: Optional[List[float]] = None,
535536
) -> None:
536537
self.name = name
537538
self._tensor = tensor
@@ -554,6 +555,7 @@ def __init__(
554555
self.feature_names: Optional[List[str]] = feature_names
555556
self.output_dtype: Optional[DataType] = output_dtype
556557
self.key_value_params: Optional[KeyValueParams] = key_value_params
558+
self.num_poolings: Optional[List[float]] = num_poolings
557559

558560
@property
559561
def tensor(self) -> torch.Tensor:

0 commit comments

Comments
 (0)