|
3 | 3 | from dataclasses import dataclass |
4 | 4 | from typing import Dict, List, Tuple |
5 | 5 |
|
| 6 | +import numpy as np |
6 | 7 | import torch |
7 | 8 |
|
| 9 | +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem |
8 | 10 | from colossalai.context.singleton_meta import SingletonMeta |
9 | 11 | from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException |
10 | 12 | from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, mix_gather_simulator, shard_simulator |
@@ -403,6 +405,158 @@ def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_d |
403 | 405 | valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict)) |
404 | 406 | return valid_spec_dict |
405 | 407 |
|
| 408 | + def mem_cost(self, comm_action_sequence: List[CommSpec]) -> TrainCycleItem: |
| 409 | + """memory cost of the communication action sequence |
| 410 | + TODO: Currently we just consider tensor numel in the shape consistency manger, |
| 411 | + as the manager itself doesn't have the access to tensor dtype, we need to take |
| 412 | + it into consideration in memory estimation. |
| 413 | +
|
| 414 | + Args: |
| 415 | + comm_action_sequence (List[CommSpec]): list of communication actions |
| 416 | +
|
| 417 | + Returns: |
| 418 | + TrainCycleItem: memory (numel) cost of such comm_action_sequence |
| 419 | + """ |
| 420 | + |
| 421 | +        def compute_shape(sharding_spec: ShardingSpec): |
| 422 | +            # work on a copy: entire_shape is a torch.Size, which does not support item assignment |
| 423 | +            shape = list(sharding_spec.entire_shape) |
| 424 | +            for dim, mesh_axes in sharding_spec.dim_partition_dict.items(): |
| 425 | +                # a dimension sharded over several mesh axes is split into the product of their sizes |
| | +                shape[dim] = shape[dim] // int(np.prod([sharding_spec.device_mesh.mesh_shape[axis] for axis in mesh_axes])) |
| | +            return shape |
| 426 | + |
| 427 | + def gather_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): |
| 428 | + """analyze all_gather memory footprint |
| 429 | + all_gather will allocate memory for the output tensor, and there will be temp memory for |
| 430 | + all_gather operation, which is twice the size of output tensor |
| 431 | +
|
| 432 | + Args: |
| 433 | + comm_spec (CommSpec): input CommSpec |
| 434 | + discard_input (bool): whether to discard the input tensor |
| 435 | + alloc_numel (int): current allocated numel |
| 436 | + peak_numel (int): current peak numel |
| 437 | + """ |
| 438 | + input_shape = compute_shape(comm_spec.sharding_spec) |
| 439 | + input_numel = np.prod(input_shape) |
| 440 | + output_numel = input_numel * comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis] |
| 441 | + peak_numel = max(peak_numel, alloc_numel + output_numel * 2) |
| 442 | + alloc_numel += output_numel |
| 443 | +            if discard_input: |
| 444 | +                alloc_numel -= input_numel |
| | +            # hand the updated counters back so the caller can accumulate them across actions |
| | +            return alloc_numel, peak_numel |
| 445 | + |
| 446 | + def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): |
| 447 | + """analyze split memory footprint |
| 448 | + split will allocate memory for the output tensor if we don't apply shard on the first dimension of |
| 449 | + the input tensor. If we apply shard on the first dimension, the `torch.tensor.contiguous()` will not |
| 450 | + generate new tensor in this case, so no memory will be allocated. |
| 451 | +
|
| 452 | + Args: |
| 453 | + comm_spec (CommSpec): input CommSpec |
| 454 | + discard_input (bool): whether to discard the input tensor |
| 455 | + alloc_numel (int): current allocated numel |
| 456 | + peak_numel (int): current peak numel |
| 457 | + """ |
| 458 | + shard_dim = comm_spec.shard_dim |
| 459 | + if shard_dim != 0: |
| 460 | + # if we don't shard the tensor on the first dimension, the split action will |
| 461 | + # generate a new tensor |
| 462 | + input_shape = compute_shape(comm_spec.sharding_spec) |
| 463 | + input_numel = np.prod(input_shape) |
| 464 | +                output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis] |
| 465 | + alloc_numel += output_numel |
| 466 | + peak_numel = max(peak_numel, alloc_numel) |
| 467 | + if discard_input: |
| 468 | + alloc_numel -= input_numel |
| 469 | + else: |
| 470 | +                # if we shard the tensor on the first dimension, the split action does not create a |
| 471 | +                # new tensor; it keeps a reference to (a view of) the input tensor, so we can ignore |
| 472 | +                # the discard_input option here. |
| 473 | +                # NOTE: this special case can break down in contrived sequences, e.g. three split actions |
| 474 | +                # that operate on the second, then the first, then again the second dimension. After the |
| 475 | +                # first two actions we still hold memory the size of the first split's output, and although |
| 476 | +                # the third split discards its input, it should really discard the tensor produced by the |
| 477 | +                # first split, so the current estimation framework overestimates the memory usage. Such |
| 478 | +                # sequences are unusual, so we ignore them for now. |
| 479 | +                pass |
| | +            return alloc_numel, peak_numel |
| 482 | + |
| 483 | + def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): |
| 484 | + """ |
| 485 | + a dummy function for reduce memory footprint analysis, as the reduce action doesn't allocate extra memory |
| 486 | + """ |
| 487 | + pass |
| 488 | + |
| 489 | + def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): |
| 490 | + """analyze all_to_all memory footprint |
| 491 | + all_to_all will allocate memory for the output tensor, and temp memory of all_to_all action |
| 492 | + is twice the size of output tensor if we shard input tensor on the first dimension, otherwise |
| 493 | + the temp memory is three times the size of output tensor |
| 494 | +
|
| 495 | + Args: |
| 496 | + comm_spec (CommSpec): input CommSpec |
| 497 | + discard_input (bool): whether to discard the input tensor |
| 498 | + alloc_numel (int): current allocated numel |
| 499 | + peak_numel (int): current peak numel |
| 500 | + """ |
| 501 | + input_shape = compute_shape(comm_spec.sharding_spec) |
| 502 | + input_numel = np.prod(input_shape) |
| 503 | + output_numel = input_numel |
| 504 | + shard_dim = comm_spec.shard_dim |
| 505 | + if shard_dim != 0: |
| 506 | + peak_numel = max(peak_numel, alloc_numel + output_numel * 3) |
| 507 | + else: |
| 508 | + peak_numel = max(peak_numel, alloc_numel + output_numel * 2) |
| 509 | + alloc_numel += output_numel |
| 510 | +            if discard_input: |
| 511 | +                alloc_numel -= input_numel |
| | +            return alloc_numel, peak_numel |
| 512 | + |
| 513 | + def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): |
| 514 | + """ |
| 515 | + a dummy function for identity memory footprint analysis, as the identity action doesn't allocate extra memory |
| 516 | + """ |
| 517 | + pass |
| 518 | + |
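| | +        # each collective pattern maps to a [forward analysis fn, backward analysis fn] pair |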
| 519 | + pattern_to_func_dict = { |
| 520 | + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis], |
| 521 | + CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: [all2all_analysis, all2all_analysis], |
| 522 | + CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: [split_analysis, gather_analysis], |
| 523 | + CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: [reduce_analysis, identity_analysis], |
| 524 | + CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: [identity_analysis, reduce_analysis], |
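| | +            # NOTE: memory analysis for the mix-gather pattern is not implemented yet, so this entry is empty |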
| 525 | + CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: [], |
| 526 | + } |
| 527 | + |
| 528 | + fwd_actions = [] |
| 529 | + bwd_actions = [] |
| 530 | + |
| 531 | + # construct forward and backward comm actions sequence |
| 532 | + for comm_spec in comm_action_sequence: |
| 533 | + comm_spec: CommSpec |
| 534 | + fwd_action, bwd_action = pattern_to_func_dict[comm_spec.comm_pattern] |
| 535 | + fwd_actions.append(fwd_action) |
| 536 | + bwd_actions.append(bwd_action) |
| 537 | + |
| 538 | + # analyze memory footprint of forward comm actions sequence |
| 539 | + fwd_alloc_numel = 0 |
| 540 | + fwd_peak_numel = 0 |
| 541 | +        for idx, (fwd_action, comm_spec) in enumerate(zip(fwd_actions, comm_action_sequence)): |
| 542 | +            # the first forward comm action does not discard its input |
| 543 | +            if idx == 0: |
| 544 | +                fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel) |
| 545 | +            else: |
| 546 | +                fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel) |
| 547 | + |
| 548 | + # analyze memory footprint for backward comm actions sequence |
| 549 | + bwd_alloc_numel = 0 |
| 550 | + bwd_peak_numel = 0 |
| 551 | +        for bwd_action, comm_spec in zip(reversed(bwd_actions), reversed(comm_action_sequence)): |
| 552 | +            bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel) |
| 553 | + |
| 554 | + fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel) |
| 555 | + bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel) |
| 556 | + total_mem = MemoryCost(activation=fwd_alloc_numel + bwd_alloc_numel) |
| 557 | + |
| 558 | + return TrainCycleItem(fwd_mem, bwd_mem, total_mem) |
| 559 | + |
406 | 560 | def shape_consistency(self, source_spec: ShardingSpec, |
407 | 561 | target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]: |
408 | 562 | ''' |
|
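A minimal, self-contained sketch of the alloc/peak bookkeeping used in `mem_cost` above, with hypothetical sizes; `gather_numel` is an illustrative stand-in for `gather_analysis` and not part of the ColossalAI API:

    def gather_numel(alloc: int, peak: int, input_numel: int, axis_size: int, discard_input: bool):
        # all_gather materializes the gathered output and needs roughly twice its numel as scratch
        output_numel = input_numel * axis_size
        peak = max(peak, alloc + output_numel * 2)
        alloc += output_numel
        if discard_input:
            alloc -= input_numel
        return alloc, peak

    # a (4, 8) shard gathered across a mesh axis of size 4, with the input kept alive
    alloc, peak = gather_numel(alloc=0, peak=0, input_numel=4 * 8, axis_size=4, discard_input=False)
    print(alloc, peak)    # 128 256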