
Commit 392be3b

yao-matrix and ydshieh authored
fix test_working_of_tp failure of accelerate ut (#39828)
Signed-off-by: Yao, Matrix <[email protected]>
Co-authored-by: Yih-Dar <[email protected]>
1 parent cc5de36 commit 392be3b

1 file changed: 2 additions, 2 deletions


src/transformers/integrations/tensor_parallel.py

Lines changed: 2 additions & 2 deletions
@@ -1019,7 +1019,7 @@ def shard_and_distribute_module(
     """
     param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
     tp_plan = model._tp_plan or {}
-    tp_plan.update(getattr(type(model), "_tp_plan", {}))
+    tp_plan.update(getattr(type(model), "_tp_plan", None) or {})
     module_to_tp = model.get_submodule(param_name)  # TODO: can i loop over modules?
     rank = int(rank)
     current_shard_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
@@ -1085,7 +1085,7 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):

 def distribute_model(model, distributed_config, device_mesh, tp_size):
     _plan = "_tp_plan"
-    tp_plan = getattr(model, "_tp_plan", {}).copy()
+    tp_plan = (getattr(model, "_tp_plan", None) or {}).copy()
     model._tp_plan = getattr(model.config, "base_model_tp_plan").copy()
     model._tp_plan.update(tp_plan)
     model._tp_size = tp_size
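
Both hunks apply the same fix: `getattr(obj, name, default)` only returns the default when the attribute is missing, not when it is present but set to `None`, so the old code could hand `None` to `.update()` or `.copy()`. Below is a minimal sketch of that failure mode and of the `or {}` pattern used here; the `Model` class is a hypothetical stand-in, not a transformers class.

```python
# Minimal sketch of the failure mode this commit fixes.
# "Model" is a hypothetical stand-in, not a transformers class.

class Model:
    _tp_plan = None  # attribute exists, but its value is None


model = Model()

# getattr()'s default is only used when the attribute is *missing*,
# so this returns None instead of {}:
old_style = getattr(model, "_tp_plan", {})
print(old_style)  # None
# old_style.copy() would raise AttributeError: 'NoneType' object has no attribute 'copy'

# The fixed form also covers the "attribute exists but is None" case:
new_style = (getattr(model, "_tp_plan", None) or {}).copy()
print(new_style)  # {}
```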
