Commit e23cecf

Manual Resharding given Manifold Paths
Differential Revision: D82241141
1 parent eac316e commit e23cecf

File tree

2 files changed: +4 additions, -174 deletions


torchrec/distributed/benchmark/benchmark_resharding_handler.py

Lines changed: 0 additions & 169 deletions
This file was deleted.

torchrec/distributed/model_parallel.py

Lines changed: 4 additions & 5 deletions
@@ -258,11 +258,10 @@ def __init__(
             device = torch.device("cpu")
         self.device: torch.device = device

-        if sharders is None:
-            sharders = get_default_sharders()
+        self.sharders = get_default_sharders() if sharders is None else sharders

         self._sharder_map: Dict[Type[nn.Module], ModuleSharder[nn.Module]] = {
-            sharder.module_type: sharder for sharder in sharders
+            sharder.module_type: sharder for sharder in self.sharders
         }

         if data_parallel_wrapper is None:
@@ -279,9 +278,9 @@ def __init__(
             )
         pg = self._env.process_group
         if pg is not None:
-            plan = planner.collective_plan(module, sharders, pg)
+            plan = planner.collective_plan(module, self.sharders, pg)
         else:
-            plan = planner.plan(module, sharders)
+            plan = planner.plan(module, self.sharders)  # pyre-ignore
         self._plan: ShardingPlan = plan
         self._dmp_wrapped_module: nn.Module = self._init_dmp(module)
         self._optim: CombinedOptimizer = self._init_optim(self._dmp_wrapped_module)
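
The functional change is small: the optional sharders argument is normalized into a self.sharders attribute instead of being consumed as a local variable, so that later code (for example a manual resharding pass) can reuse the exact sharder list that produced the original plan. Below is a minimal, self-contained sketch of that pattern; the ReshardableModel wrapper, its reshard method, and the Sharder/get_default_sharders stubs are illustrative assumptions, not the torchrec API.

from typing import Dict, List, Optional, Type

import torch.nn as nn


class Sharder:
    """Stand-in for a module sharder (illustrative only)."""

    module_type: Type[nn.Module] = nn.Module


def get_default_sharders() -> List[Sharder]:
    # Assumption: returns the framework's default sharder list.
    return [Sharder()]


class ReshardableModel:
    """Hypothetical wrapper showing why sharders are kept on `self`."""

    def __init__(
        self, module: nn.Module, sharders: Optional[List[Sharder]] = None
    ) -> None:
        # Normalize the optional argument once and keep it as an attribute,
        # mirroring the diff's
        # `self.sharders = get_default_sharders() if sharders is None else sharders`.
        self.sharders: List[Sharder] = (
            get_default_sharders() if sharders is None else sharders
        )
        self._sharder_map: Dict[Type[nn.Module], Sharder] = {
            s.module_type: s for s in self.sharders
        }
        self._module = module

    def reshard(self) -> None:
        # Because the sharders were captured at construction time, a later
        # resharding pass can re-plan with the same sharder set instead of
        # requiring the caller to pass them again.
        _ = (self._module, self.sharders)  # placeholder for a real planner call

The design choice is simply to capture constructor inputs that a later mutation step needs, rather than recomputing them or requiring them to be passed again at reshard time.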
