Commit d373e67

[hotfix] resharding cost issue (#1742)
1 parent: 24e84eb

File tree

1 file changed: +13 −6 lines changed

colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py

Lines changed: 13 additions & 6 deletions
@@ -1,14 +1,21 @@
 import operator
-from functools import reduce
 import warnings
+from copy import deepcopy
+from functools import reduce
+from typing import Dict, List
+
 import torch
+
+from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
+    enumerate_all_possible_1d_sharding,
+    enumerate_all_possible_2d_sharding,
+    exception_handler,
+)
 from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
-from .operator_handler import OperatorHandler
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
-from copy import deepcopy
-from typing import Dict, List
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
+
+from .operator_handler import OperatorHandler

 __all__ = ['WhereHandler']

@@ -94,7 +101,7 @@ def _generate_resharding_costs(self, sharding_specs):
             # compute the resharding cost
             _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
                 input_sharding_spec, input_spec)
-
+            total_resharding_cost = total_resharding_cost['total']
             # we need multiply the size of elem dtype to get correct communication cost
             resharding_cost = total_resharding_cost * size_per_elem_bytes
             resharding_costs[input_node].append(resharding_cost)

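Why the fix works: `ShapeConsistencyManager.shape_consistency` now reports its resharding cost as a dict of cost components rather than a bare number, so the handler must read the `'total'` entry before converting the element count into a communication cost in bytes. Below is a minimal self-contained sketch of that pattern, not the repository's code; the component keys and sample values are assumptions for illustration only.

    # Minimal sketch (not ColossalAI's code): shape_consistency() is assumed
    # to return its resharding cost as a dict of components with a
    # precomputed 'total', instead of the plain float it used to return.
    from typing import Dict, Union


    def resharding_cost_bytes(total_resharding_cost: Union[float, Dict[str, float]],
                              size_per_elem_bytes: int) -> float:
        # The cost is counted in elements moved; only the aggregate matters
        # when comparing sharding strategies, so pull out the 'total' entry.
        if isinstance(total_resharding_cost, dict):
            total_resharding_cost = total_resharding_cost['total']
        # Multiply by the element size in bytes to get the communication cost.
        return total_resharding_cost * size_per_elem_bytes


    # Hypothetical cost dict; the non-'total' keys are illustrative only.
    cost = {'forward': 12.0, 'backward': 12.0, 'total': 24.0}
    print(resharding_cost_bytes(cost, 4))    # 96.0, e.g. for fp32 elements

With the `'total'` entry extracted, `resharding_cost` is once again a plain number, which is what the downstream strategy comparison expects; without the hotfix, the multiplication in the handler would raise a TypeError on the dict.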