
Commit c908517

iamzainhuda authored and facebook-github-bot committed
Add row based sharding support for FeatureProcessedEBC (#3281)
Summary:
Pull Request resolved: #3281

In this diff we introduce row based sharding type support (TWRW, RW, GRID) for feature processors. Previously, feature processors did not support row based sharding because feature processors are data parallel: splitting the input across row based shards meant the feature processor weights that were accessed were incorrect. In column wise and data parallel sharding approaches the input is duplicated across ranks, which guarantees the correct weight is accessed on every rank.

The indices/buckets are normally calculated after the input split/distribution. To make this compatible with row based sharding, we calculate them before the input split/distribution, which couples the train pipeline and the feature processors. For each feature we preprocess the input and place the calculated indices in KJT.weights; the indices then propagate correctly and index into the right weight for the final step of the feature processing.

This applies in both pipelined and non pipelined setups: the input modification is done either at the pipelined forward call or in the input dist of the FPEBC, as determined by the pipelining flag set through _rewrite_model in the train pipeline.

Reviewed By: che-sh

Differential Revision: D69125073

fbshipit-source-id: 0cc5bd49f9bf84476904abe1b9473048d212b0e2
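To illustrate the idea, here is a minimal sketch only (the real logic lives in modify_input_for_feature_processor in torchrec/distributed/utils.py and is not reproduced here): the per-value position indices can be computed from the KJT lengths before the input split and carried in KJT.weights, so whichever rank receives a value also receives the index into its local feature processor weight:

# Hedged sketch, not the TorchRec implementation: compute each value's
# position within its bag *before* the input dist and stash it in
# KJT.weights so it travels with the value to the owning rank.
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def stash_position_indices(
    features: KeyedJaggedTensor, max_feature_length: int
) -> KeyedJaggedTensor:
    lengths = features.lengths()
    # exclusive cumsum gives the start offset of every bag
    bag_starts = torch.repeat_interleave(
        torch.cumsum(lengths, dim=0) - lengths, lengths
    )
    positions = torch.arange(
        features.values().numel(), device=lengths.device
    ) - bag_starts
    positions = positions.clamp(max=max_feature_length - 1)
    # the position indices ride along in weights through the input dist
    return KeyedJaggedTensor(
        keys=features.keys(),
        values=features.values(),
        lengths=lengths,
        weights=positions.to(torch.float32),
    )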
1 parent 3f94743 commit c908517

File tree: 9 files changed (+330, -26 lines)


torchrec/distributed/fp_embeddingbag.py

Lines changed: 50 additions & 9 deletions
@@ -8,7 +8,18 @@
 # pyre-strict

 from functools import partial
-from typing import Any, Dict, Iterator, List, Optional, Type, Union
+from typing import (
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)

 import torch
 from torch import nn
@@ -31,14 +42,20 @@
     ShardingEnv,
     ShardingType,
 )
-from torchrec.distributed.utils import append_prefix, init_parameters
+from torchrec.distributed.utils import (
+    append_prefix,
+    init_parameters,
+    modify_input_for_feature_processor,
+)
 from torchrec.modules.feature_processor_ import FeatureProcessorsCollection
 from torchrec.modules.fp_embedding_modules import (
     apply_feature_processors_to_kjt,
     FeatureProcessedEmbeddingBagCollection,
 )
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor

+_T = TypeVar("_T")
+

 def param_dp_sync(kt: KeyedTensor, no_op_tensor: torch.Tensor) -> KeyedTensor:
     kt._values.add_(no_op_tensor)
@@ -74,6 +91,16 @@ def __init__(
            )
        )

+        self._row_wise_sharded: bool = False
+        for param_sharding in table_name_to_parameter_sharding.values():
+            if param_sharding.sharding_type in [
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_ROW_WISE.value,
+                ShardingType.GRID_SHARD.value,
+            ]:
+                self._row_wise_sharded = True
+                break
+
         self._lookups: List[nn.Module] = self._embedding_bag_collection._lookups

         self._is_collection: bool = False
@@ -96,6 +123,11 @@ def __init__(
     def input_dist(
         self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
     ) -> Awaitable[Awaitable[KJTList]]:
+        if not self.is_pipelined and self._row_wise_sharded:
+            # transform input to support row based sharding when not pipelined
+            modify_input_for_feature_processor(
+                features, self._feature_processors, self._is_collection
+            )
         return self._embedding_bag_collection.input_dist(ctx, features)

     def apply_feature_processors_to_kjt_list(self, dist_input: KJTList) -> KJTList:
@@ -105,10 +137,7 @@ def apply_feature_processors_to_kjt_list(self, dist_input: KJTList) -> KJTList:
                 kjt_list.append(self._feature_processors(features))
             else:
                 kjt_list.append(
-                    apply_feature_processors_to_kjt(
-                        features,
-                        self._feature_processors,
-                    )
+                    apply_feature_processors_to_kjt(features, self._feature_processors)
                 )
         return KJTList(kjt_list)

@@ -117,7 +146,6 @@ def compute(
         ctx: EmbeddingBagCollectionContext,
         dist_input: KJTList,
     ) -> List[torch.Tensor]:
-
         fp_features = self.apply_feature_processors_to_kjt_list(dist_input)
         return self._embedding_bag_collection.compute(ctx, fp_features)

@@ -163,6 +191,18 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
             if "_embedding_bag_collection" in fqn:
                 yield append_prefix(prefix, fqn)

+    def preprocess_input(
+        self, args: List[_T], kwargs: Mapping[str, _T]
+    ) -> Tuple[List[_T], Mapping[str, _T]]:
+        for x in args + list(kwargs.values()):
+            if isinstance(x, KeyedJaggedTensor):
+                modify_input_for_feature_processor(
+                    features=x,
+                    feature_processors=self._feature_processors,
+                    is_collection=self._is_collection,
+                )
+        return args, kwargs
+

 class FeatureProcessedEmbeddingBagCollectionSharder(
     BaseEmbeddingSharder[FeatureProcessedEmbeddingBagCollection]
@@ -188,7 +228,6 @@ def shard(
         device: Optional[torch.device] = None,
         module_fqn: Optional[str] = None,
     ) -> ShardedFeatureProcessedEmbeddingBagCollection:
-
         if device is None:
             device = torch.device("cuda")

@@ -225,12 +264,14 @@ def sharding_types(self, compute_device_type: str) -> List[str]:
         if compute_device_type in {"mtia"}:
             return [ShardingType.TABLE_WISE.value, ShardingType.COLUMN_WISE.value]

-        # No row wise because position weighted FP and RW don't play well together.
         types = [
             ShardingType.DATA_PARALLEL.value,
             ShardingType.TABLE_WISE.value,
             ShardingType.COLUMN_WISE.value,
             ShardingType.TABLE_COLUMN_WISE.value,
+            ShardingType.TABLE_ROW_WISE.value,
+            ShardingType.ROW_WISE.value,
+            ShardingType.GRID_SHARD.value,
         ]

         return types
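With ROW_WISE, TABLE_ROW_WISE, and GRID_SHARD now advertised by the sharder, a FeatureProcessedEmbeddingBagCollection can be planned row wise. The following is a condensed sketch of the pattern exercised by the new TrainPipelineGPUTest further down; the single-table config and device choices here are illustrative assumptions, not part of the diff:

# Assumption-laden sketch: plan a feature processed EBC row wise across 2 ranks.
import torch
from torchrec.distributed.fp_embeddingbag import (
    FeatureProcessedEmbeddingBagCollectionSharder,
)
from torchrec.distributed.sharding_plan import construct_module_sharding_plan, row_wise
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection

tables = [
    EmbeddingBagConfig(
        name="table_0",
        embedding_dim=8,
        num_embeddings=100,
        feature_names=["feature_0"],
    )
]
fp_ebc = FeatureProcessedEmbeddingBagCollection(
    EmbeddingBagCollection(tables=tables, is_weighted=True, device=torch.device("meta")),
    PositionWeightedModuleCollection(max_feature_lengths={"feature_0": 10}),
)
sharder = FeatureProcessedEmbeddingBagCollectionSharder()
module_sharding_plan = construct_module_sharding_plan(
    fp_ebc,
    per_param_sharding={"table_0": row_wise()},  # now accepted by the sharder
    world_size=2,
    device_type="cuda",
    sharder=sharder,
)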

torchrec/distributed/tests/test_fp_embeddingbag.py

Lines changed: 0 additions & 1 deletion
@@ -231,7 +231,6 @@ class ShardedEmbeddingBagCollectionParallelTest(MultiProcessTestBase):
     def test_sharding_ebc(
         self, set_gradient_division: bool, use_dmp: bool, use_fp_collection: bool
     ) -> None:
-
         import hypothesis

         # don't need to test entire matrix

torchrec/distributed/tests/test_fp_embeddingbag_utils.py

Lines changed: 6 additions & 1 deletion
@@ -86,7 +86,12 @@ def forward(self, kjt: KeyedJaggedTensor) -> Tuple[torch.Tensor, torch.Tensor]:
         pred = torch.cat(
             [
                 fp_ebc_out[key]
-                for key in ["feature_0", "feature_1", "feature_2", "feature_3"]
+                for key in [
+                    "feature_0",
+                    "feature_1",
+                    "feature_2",
+                    "feature_3",
+                ]
             ],
             dim=1,
         )

torchrec/distributed/train_pipeline/tests/test_train_pipelines.py

Lines changed: 164 additions & 1 deletion
@@ -23,7 +23,10 @@
 from torch._dynamo.utils import counters
 from torch.fx._symbolic_trace import is_fx_tracing
 from torchrec.distributed import DistributedModelParallel
-from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.embedding_types import (
+    EmbeddingComputeKernel,
+    EmbeddingTableConfig,
+)
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
 from torchrec.distributed.fp_embeddingbag import (
     FeatureProcessedEmbeddingBagCollectionSharder,
@@ -32,8 +35,13 @@
 from torchrec.distributed.model_parallel import DMPCollection
 from torchrec.distributed.sharding_plan import (
     construct_module_sharding_plan,
+    row_wise,
     table_wise,
 )
+from torchrec.distributed.test_utils.multi_process import (
+    MultiProcessContext,
+    MultiProcessTestBase,
+)
 from torchrec.distributed.test_utils.test_model import (
     ModelInput,
     TestEBCSharder,
@@ -342,6 +350,161 @@ def test_equal_to_non_pipelined_with_input_transformer(self) -> None:
         torch.testing.assert_close(pred_gpu.cpu(), pred)


+def fp_ebc(
+    rank: int,
+    world_size: int,
+    tables: List[EmbeddingTableConfig],
+    weighted_tables: List[EmbeddingTableConfig],
+    data: List[Tuple[ModelInput, List[ModelInput]]],
+    backend: str = "nccl",
+    local_size: Optional[int] = None,
+) -> None:
+    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
+        assert ctx.pg is not None
+        sharder = cast(
+            ModuleSharder[nn.Module],
+            FeatureProcessedEmbeddingBagCollectionSharder(),
+        )
+
+        class DummyWrapper(nn.Module):
+            def __init__(self, sparse_arch):
+                super().__init__()
+                self.m = sparse_arch
+
+            def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
+                return self.m(model_input.idlist_features)
+
+        max_feature_lengths = [10, 10, 12, 12]
+        sparse_arch = DummyWrapper(
+            create_module_and_freeze(
+                tables=tables,  # pyre-ignore[6]
+                device=ctx.device,
+                use_fp_collection=False,
+                max_feature_lengths=max_feature_lengths,
+            )
+        )
+
+        # compute_kernel = EmbeddingComputeKernel.FUSED.value
+        module_sharding_plan = construct_module_sharding_plan(
+            sparse_arch.m._fp_ebc,
+            per_param_sharding={
+                "table_0": row_wise(),
+                "table_1": row_wise(),
+                "table_2": row_wise(),
+                "table_3": row_wise(),
+            },
+            world_size=2,
+            device_type=ctx.device.type,
+            sharder=sharder,
+        )
+        sharded_sparse_arch_pipeline = DistributedModelParallel(
+            module=copy.deepcopy(sparse_arch),
+            plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}),
+            env=ShardingEnv.from_process_group(ctx.pg),  # pyre-ignore[6]
+            sharders=[sharder],
+            device=ctx.device,
+        )
+        sharded_sparse_arch_no_pipeline = DistributedModelParallel(
+            module=copy.deepcopy(sparse_arch),
+            plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}),
+            env=ShardingEnv.from_process_group(ctx.pg),  # pyre-ignore[6]
+            sharders=[sharder],
+            device=ctx.device,
+        )
+
+        batches = []
+        for d in data:
+            batches.append(d[1][ctx.rank].to(ctx.device))
+        dataloader = iter(batches)
+
+        optimizer_no_pipeline = optim.SGD(
+            sharded_sparse_arch_no_pipeline.parameters(), lr=0.1
+        )
+        optimizer_pipeline = optim.SGD(
+            sharded_sparse_arch_pipeline.parameters(), lr=0.1
+        )
+
+        pipeline = TrainPipelineSparseDist(
+            sharded_sparse_arch_pipeline,
+            optimizer_pipeline,
+            ctx.device,
+        )
+
+        for batch in batches[:-2]:
+            batch = batch.to(ctx.device)
+            optimizer_no_pipeline.zero_grad()
+            loss, pred = sharded_sparse_arch_no_pipeline(batch)
+            loss.backward()
+            optimizer_no_pipeline.step()
+
+            pred_pipeline = pipeline.progress(dataloader)
+            torch.testing.assert_close(pred_pipeline.cpu(), pred.cpu())
+
+
+class TrainPipelineGPUTest(MultiProcessTestBase):
+    def setUp(self, backend: str = "nccl") -> None:
+        super().setUp()
+
+        self.pipeline_class = TrainPipelineSparseDist
+        num_features = 4
+        num_weighted_features = 4
+        self.tables = [
+            EmbeddingBagConfig(
+                num_embeddings=(i + 1) * 100,
+                embedding_dim=(i + 1) * 4,
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+            )
+            for i in range(num_features)
+        ]
+        self.weighted_tables = [
+            EmbeddingBagConfig(
+                num_embeddings=(i + 1) * 100,
+                embedding_dim=(i + 1) * 4,
+                name="weighted_table_" + str(i),
+                feature_names=["weighted_feature_" + str(i)],
+            )
+            for i in range(num_weighted_features)
+        ]
+
+        self.backend = backend
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        else:
+            self.device = torch.device("cpu")
+
+        if self.backend == "nccl" and self.device == torch.device("cpu"):
+            self.skipTest("NCCL not supported on CPUs.")
+
+    def _generate_data(
+        self,
+        num_batches: int = 5,
+        batch_size: int = 1,
+        max_feature_lengths: Optional[List[int]] = None,
+    ) -> List[Tuple[ModelInput, List[ModelInput]]]:
+        return [
+            ModelInput.generate(
+                tables=self.tables,
+                weighted_tables=self.weighted_tables,
+                batch_size=batch_size,
+                world_size=2,
+                num_float_features=10,
+                max_feature_lengths=max_feature_lengths,
+            )
+            for i in range(num_batches)
+        ]
+
+    def test_fp_ebc_rw(self) -> None:
+        data = self._generate_data(max_feature_lengths=[10, 10, 12, 12])
+        self._run_multi_process_test(
+            callable=fp_ebc,
+            world_size=2,
+            tables=self.tables,
+            weighted_tables=self.weighted_tables,
+            data=data,
+        )
+
+
 class TrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase):
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @unittest.skipIf(

torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def setUp(self) -> None:
         self.pg = init_distributed_single_host(backend=backend, rank=0, world_size=1)

         num_features = 4
-        num_weighted_features = 2
+        num_weighted_features = 4
         self.tables = [
             EmbeddingBagConfig(
                 num_embeddings=(i + 1) * 100,

torchrec/distributed/train_pipeline/utils.py

Lines changed: 3 additions & 0 deletions
@@ -147,6 +147,7 @@ def _start_data_dist(
         # and this info was done in the _rewrite_model by tracing the
         # entire model to get the arg_info_list
         args, kwargs = forward.args.build_args_kwargs(batch)
+        args, kwargs = module.preprocess_input(args, kwargs)

         # Start input distribution.
         module_ctx = module.create_context()
@@ -379,6 +380,8 @@ def _rewrite_model(  # noqa C901
             logger.info(f"Module '{node.target}' will be pipelined")
             child = sharded_modules[node.target]
             original_forwards.append(child.forward)
+            # Set pipelining flag on the child module
+            child.is_pipelined = True
             # pyre-ignore[8] Incompatible attribute type
             child.forward = pipelined_forward(
                 node.target,
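Read together with the FPEBC changes above, the two hooks give the following control flow. This is an illustrative sketch, not the literal pipeline code: _rewrite_model flags the pipelined sharded module, so the module's own input_dist skips the KJT modification and the pipeline applies it via preprocess_input just before starting the input dist.

# Illustrative only: condensed view of where the KJT modification happens in
# the pipelined vs. non-pipelined paths (names mirror the diffs above).
from torchrec.distributed.utils import modify_input_for_feature_processor


def start_data_dist_sketch(module, forward, batch, module_ctx) -> None:
    # pipelined path: the pipeline rewrites the inputs before the split
    args, kwargs = forward.args.build_args_kwargs(batch)
    args, kwargs = module.preprocess_input(args, kwargs)
    module.input_dist(module_ctx, *args, **kwargs)


def fpebc_input_dist_sketch(self, ctx, features):
    # non-pipelined path: the sharded FP-EBC rewrites its own input
    if not self.is_pipelined and self._row_wise_sharded:
        modify_input_for_feature_processor(
            features, self._feature_processors, self._is_collection
        )
    return self._embedding_bag_collection.input_dist(ctx, features)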
