@@ -1,14 +1,21 @@
 import operator
-from functools import reduce
 import warnings
+from copy import deepcopy
+from functools import reduce
+from typing import Dict, List
+
 import torch
+
+from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
+    enumerate_all_possible_1d_sharding,
+    enumerate_all_possible_2d_sharding,
+    exception_handler,
+)
 from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
-from .operator_handler import OperatorHandler
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
-from copy import deepcopy
-from typing import Dict, List
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
+
+from .operator_handler import OperatorHandler
 
 __all__ = ['WhereHandler']
 
@@ -94,7 +101,7 @@ def _generate_resharding_costs(self, sharding_specs):
                 # compute the resharding cost
                 _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
                     input_sharding_spec, input_spec)
-
+                total_resharding_cost = total_resharding_cost['total']
                 # we need multiply the size of elem dtype to get correct communication cost
                 resharding_cost = total_resharding_cost * size_per_elem_bytes
                 resharding_costs[input_node].append(resharding_cost)
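The change in this hunk suggests that ShapeConsistencyManager.shape_consistency now reports its resharding cost as a dictionary of components rather than a bare number, so the handler unpacks the 'total' entry before scaling it by the per-element size in bytes. Below is a minimal sketch of that cost computation; the illustrative cost values and the extra 'forward'/'backward' keys are assumptions for demonstration, only the 'total' key is taken from the diff.

import torch

# Assumed dictionary shape of the cost returned by shape_consistency();
# only the 'total' key is confirmed by the diff, the rest is illustrative.
total_resharding_cost = {'forward': 10.0, 'backward': 10.0, 'total': 20.0}

# Same scaling as in the handler: convert an element-count cost into bytes
# for the tensor's dtype (float32 here gives 4 bytes per element).
size_per_elem_bytes = torch.tensor([], dtype=torch.float32).element_size()

resharding_cost = total_resharding_cost['total'] * size_per_elem_bytes
print(resharding_cost)  # 80.0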