Commit cfe2a9b

[autoparallel] memory estimation for shape consistency (#2144)
* [fx] metainfo class for auto parallel
* [fx] add unit test for linear metainfo
* [fx] fix bwd param for linear
* [fx] modify unit test
* [fx] modify unit test
* [fx] modify import
* [fx] modify import
* [fx] modify import
* [fx] move meta profiler to auto parallel
* [fx] add conv metainfo class
* [fx] restore profiler
* [fx] restore meta profiler
* [autoparallel] modify unit test
* [fx] modify unit test
* [autoparallel] add batchnorm metainfo class
* [autoparallel] fix batchnorm unit test function declaration
* [fx] restore profiler
* [fx] add relu metainfo class
* [fx] restore profiler
* [autoparallel] modify metainfo input
* [autoparallel] add pooling metainfo
* [autoparallel] add F.linear metainfo generator
* [autoparallel] add binary elementwise metainfo
* [fx] recover profiler
* [autoparallel] fix forward memory calculation
* [autoparallel] modify constants.py
* [autoparallel] remove redundant print
* [autoparallel] add F.conv metainfo
* [autoparallel] linear fix
* [autoparallel] memory estimation for communication actions
* [autoparallel] fix docstring
* [autoparallel] fix variables name
1 parent b87496a commit cfe2a9b

File tree

2 files changed: +155 -1 lines changed


colossalai/auto_parallel/tensor_shard/sharding_strategy.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 import torch
 from torch.fx.node import Node
 
-from colossalai.tensor.shape_consistency import CommSpec
+from colossalai.tensor.comm_spec import CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
 
 from .constants import (
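The redirected import above likely breaks an import cycle: shape_consistency.py (the next file in this commit) starts importing MemoryCost and TrainCycleItem from sharding_strategy.py, so sharding_strategy.py now pulls CommSpec from its definition module instead of through shape_consistency.py. A minimal illustration of the resulting dependency direction, inferred from the diff rather than stated in the commit message:

    # sharding_strategy.py depends on comm_spec.py directly (no longer via shape_consistency.py)
    from colossalai.tensor.comm_spec import CommSpec

    # shape_consistency.py now depends on sharding_strategy.py for the cost dataclasses
    from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem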

colossalai/tensor/shape_consistency.py

Lines changed: 154 additions & 0 deletions
@@ -3,8 +3,10 @@
 from dataclasses import dataclass
 from typing import Dict, List, Tuple
 
+import numpy as np
 import torch
 
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
 from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
 from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, mix_gather_simulator, shard_simulator

@@ -403,6 +405,158 @@ def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_d
         valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict))
         return valid_spec_dict
 
+    def mem_cost(self, comm_action_sequence: List[CommSpec]) -> TrainCycleItem:
+        """memory cost of the communication action sequence
+        TODO: Currently we only consider tensor numel in the shape consistency manager,
+        as the manager itself doesn't have access to the tensor dtype; we need to take
+        it into consideration in memory estimation.
+
+        Args:
+            comm_action_sequence (List[CommSpec]): list of communication actions
+
+        Returns:
+            TrainCycleItem: memory (numel) cost of such comm_action_sequence
+        """
+
+        def compute_shape(sharding_spec: ShardingSpec):
+            shape = sharding_spec.entire_shape
+            for dim, shard in sharding_spec.dim_partition_dict.items():
+                shape[dim] = shape[dim] // len(shard)
+            return shape
+
+        def gather_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
+            """analyze all_gather memory footprint
+            all_gather will allocate memory for the output tensor, plus temporary memory for the
+            all_gather operation, so the transient peak is twice the size of the output tensor
+
+            Args:
+                comm_spec (CommSpec): input CommSpec
+                discard_input (bool): whether to discard the input tensor
+                alloc_numel (int): current allocated numel
+                peak_numel (int): current peak numel
+            """
+            input_shape = compute_shape(comm_spec.sharding_spec)
+            input_numel = np.prod(input_shape)
+            output_numel = input_numel * comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
+            peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
+            alloc_numel += output_numel
+            if discard_input:
+                alloc_numel -= input_numel
+
+        def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
+            """analyze split memory footprint
+            split will allocate memory for the output tensor if we don't shard the first dimension of
+            the input tensor. If we shard the first dimension, `torch.Tensor.contiguous()` will not
+            generate a new tensor in this case, so no memory will be allocated.
+
+            Args:
+                comm_spec (CommSpec): input CommSpec
+                discard_input (bool): whether to discard the input tensor
+                alloc_numel (int): current allocated numel
+                peak_numel (int): current peak numel
+            """
+            shard_dim = comm_spec.shard_dim
+            if shard_dim != 0:
+                # if we don't shard the tensor on the first dimension, the split action will
+                # generate a new tensor
+                input_shape = compute_shape(comm_spec.sharding_spec)
+                input_numel = np.prod(input_shape)
+                output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
+                alloc_numel += output_numel
+                peak_numel = max(peak_numel, alloc_numel)
+                if discard_input:
+                    alloc_numel -= input_numel
+            else:
+                # if we shard the tensor on the first dimension, the split action will not generate
+                # a new tensor, and as it will preserve a reference to the input tensor, we could
+                # override the discard_input option here
+                # NOTE: this special case might fail in some weird cases, e.g. if we have three split
+                # actions in the comm actions sequence, the first split action operates on the second dimension,
+                # the second split action operates on the first dimension, and the third split action operates, again,
+                # on the second dimension. Therefore, after the first two actions in the sequence, we will allocate
+                # memory of the same size as the output of the first split action. However, the third split action will discard
+                # the input tensor, when it actually should discard the tensor generated by the first split action, so in
+                # the current memory estimation framework, we will overestimate the memory usage. But the above case is
+                # kind of weird, and I think we could ignore it for now.
+                pass
+
+        def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
+            """
+            a dummy function for reduce memory footprint analysis, as the reduce action doesn't allocate extra memory
+            """
+            pass
+
+        def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
+            """analyze all_to_all memory footprint
+            all_to_all will allocate memory for the output tensor, and the temp memory of the all_to_all action
+            is twice the size of the output tensor if we shard the input tensor on the first dimension, otherwise
+            the temp memory is three times the size of the output tensor
+
+            Args:
+                comm_spec (CommSpec): input CommSpec
+                discard_input (bool): whether to discard the input tensor
+                alloc_numel (int): current allocated numel
+                peak_numel (int): current peak numel
+            """
+            input_shape = compute_shape(comm_spec.sharding_spec)
+            input_numel = np.prod(input_shape)
+            output_numel = input_numel
+            shard_dim = comm_spec.shard_dim
+            if shard_dim != 0:
+                peak_numel = max(peak_numel, alloc_numel + output_numel * 3)
+            else:
+                peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
+            alloc_numel += output_numel
+            if discard_input:
+                alloc_numel -= input_numel
+
+        def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
+            """
+            a dummy function for identity memory footprint analysis, as the identity action doesn't allocate extra memory
+            """
+            pass
+
+        pattern_to_func_dict = {
+            CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis],
+            CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: [all2all_analysis, all2all_analysis],
+            CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: [split_analysis, gather_analysis],
+            CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: [reduce_analysis, identity_analysis],
+            CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: [identity_analysis, reduce_analysis],
+            CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: [],
+        }
+
+        fwd_actions = []
+        bwd_actions = []
+
+        # construct forward and backward comm actions sequence
+        for comm_spec in comm_action_sequence:
+            comm_spec: CommSpec
+            fwd_action, bwd_action = pattern_to_func_dict[comm_spec.comm_pattern]
+            fwd_actions.append(fwd_action)
+            bwd_actions.append(bwd_action)
+
+        # analyze memory footprint of forward comm actions sequence
+        fwd_alloc_numel = 0
+        fwd_peak_numel = 0
+        for idx, (fwd_action, comm_spec) in enumerate(zip(fwd_actions, comm_action_sequence)):
+            # the first forward comm action will not discard input
+            if idx == 0:
+                fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel)
+            else:
+                fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
+
+        # analyze memory footprint for backward comm actions sequence
+        bwd_alloc_numel = 0
+        bwd_peak_numel = 0
+        for idx, (bwd_action, comm_spec) in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
+            bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
+
+        fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
+        bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel)
+        total_mem = MemoryCost(activation=fwd_alloc_numel + bwd_alloc_numel)
+
+        return TrainCycleItem(fwd_mem, bwd_mem, total_mem)
+
     def shape_consistency(self, source_spec: ShardingSpec,
                           target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]:
         '''
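As a rough illustration of the numel bookkeeping that mem_cost performs, the standalone sketch below replays the all_gather and all_to_all rules on plain integers. The function names and example sizes are made up for this sketch, and the updated counters are returned explicitly here, which is an adaptation for a runnable example rather than the commit's exact code.

    # Toy, self-contained version of the allocated/peak numel accounting used above.
    # Shapes and mesh sizes are invented for illustration.

    def gather_numel(input_numel: int, mesh_axis_size: int, discard_input: bool,
                     alloc_numel: int, peak_numel: int):
        # all_gather: the output is mesh_axis_size times larger than the input, and the
        # transient peak counts roughly twice the output on top of what is already allocated
        output_numel = input_numel * mesh_axis_size
        peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
        alloc_numel += output_numel
        if discard_input:
            alloc_numel -= input_numel
        return alloc_numel, peak_numel

    def all2all_numel(input_numel: int, shard_dim: int, discard_input: bool,
                      alloc_numel: int, peak_numel: int):
        # all_to_all: the output numel equals the input numel; the transient peak is 2x the
        # output when the first dimension is sharded, 3x otherwise
        output_numel = input_numel
        factor = 2 if shard_dim == 0 else 3
        peak_numel = max(peak_numel, alloc_numel + output_numel * factor)
        alloc_numel += output_numel
        if discard_input:
            alloc_numel -= input_numel
        return alloc_numel, peak_numel

    # a two-step forward sequence: all_gather over a mesh axis of size 4, then all_to_all
    alloc, peak = 0, 0
    alloc, peak = gather_numel(1024, 4, discard_input=False, alloc_numel=alloc, peak_numel=peak)
    alloc, peak = all2all_numel(4096, shard_dim=1, discard_input=True, alloc_numel=alloc, peak_numel=peak)
    print(alloc, peak)    # 4096 numel still allocated, transient peak of 16384 numel

In the repository itself, the CommSpec sequence fed to mem_cost would typically come from shape_consistency, whose signature appears in the context lines above. A hypothetical call through the manager (assuming the class name ShapeConsistencyManager and that source_spec and target_spec are ShardingSpec objects built elsewhere):

    manager = ShapeConsistencyManager()
    transform_path, comm_action_sequence, resharding_cost = manager.shape_consistency(source_spec, target_spec)
    mem = manager.mem_cost(comm_action_sequence)    # TrainCycleItem holding fwd/bwd/total MemoryCost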
