diff --git a/csrc/evaluator_common.cpp b/csrc/evaluator_common.cpp index 5e983777b04..bdf5e70e77f 100644 --- a/csrc/evaluator_common.cpp +++ b/csrc/evaluator_common.cpp @@ -382,7 +382,7 @@ void PrecomputedValues::bindTensorMetaData( tensor.dim() == static_cast(logical_domain.size()), "Something went wrong configuring launch. Inputs do not match."); - std::vector logical_sizes = unshardedSizes(tv, tensor.sizes()); + std::vector logical_sizes = unshardedSizes(tv, tensor.sizes(), &extent_to_multiplier_map_); adjustEvaluatorSizes(tv, logical_sizes); for (const auto dim : arange(static_cast(logical_domain.size()))) { diff --git a/csrc/evaluator_common.h b/csrc/evaluator_common.h index aabf029ed4d..c68045535ac 100644 --- a/csrc/evaluator_common.h +++ b/csrc/evaluator_common.h @@ -181,6 +181,11 @@ class PrecomputedValues { return has_valid_values_; } + //! Get the extent to multiplier map for unshardedSizes + std::unordered_map* getExtentToMultiplierMap() { + return &extent_to_multiplier_map_; + } + //! Runs the internal value machine that will compute //! the values allocated in the workspace. void evaluate(); @@ -289,6 +294,9 @@ class PrecomputedValues { //! Stores the IR nodes corresponding to each index. std::vector symbols_; + //! Extent to multiplier map for unshardedSizes - owned by this PrecomputedValues + std::unordered_map extent_to_multiplier_map_; + //! An internal log to keep track of all the bindings //! used in each evaluation cycle. To be used for //! consistency check. diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp index 38bbad4fda7..afec1cd2ca5 100644 --- a/csrc/expr_evaluator.cpp +++ b/csrc/expr_evaluator.cpp @@ -146,7 +146,7 @@ void ExpressionEvaluator::bindTensorDomain( tv->toString(), ", to be bound to a tensor of equal rank."); - std::vector logical_sizes = unshardedSizes(tv, t.sizes()); + std::vector logical_sizes = unshardedSizes(tv, t.sizes(), getExtentToMultiplierMap()); adjustEvaluatorSizes(tv, logical_sizes); for (const auto& [i, id] : enumerate(logical_domain)) { diff --git a/csrc/expr_evaluator.h b/csrc/expr_evaluator.h index 2c7b2fdb2ef..ca4190fea0a 100644 --- a/csrc/expr_evaluator.h +++ b/csrc/expr_evaluator.h @@ -79,6 +79,11 @@ class ExpressionEvaluator { return precomputed_values_; } + //! Get the extent to multiplier map for unshardedSizes from PrecomputedValues + std::unordered_map* getExtentToMultiplierMap() const { + return precomputed_values_ ? precomputed_values_->getExtentToMultiplierMap() : nullptr; + } + //! Augment the evaluator with the exact root-domain map such that //! if the extent of a root ID is known, the extents of all other //! root IDs that are exactly mapped also get bound to the same diff --git a/csrc/multidevice/execution_utils.cpp b/csrc/multidevice/execution_utils.cpp index a7a7da703e2..39a05cdde4a 100644 --- a/csrc/multidevice/execution_utils.cpp +++ b/csrc/multidevice/execution_utils.cpp @@ -50,7 +50,8 @@ at::Tensor shardTensor( std::vector unshardedSizes( const TensorView* tv, - c10::IntArrayRef sizes) { + c10::IntArrayRef sizes, + std::unordered_map* extent_to_multiplier_map) { std::vector unsharded_sizes = sizes.vec(); for (ParallelType parallel_type : deviceAndStreamParallelTypes()) { const DomainType domain_type = parallel_type == ParallelType::Stream @@ -101,6 +102,27 @@ std::vector unshardedSizes( NVF_THROW("Unexpected parallel type: ", parallel_type); }(); + + // Check consistency: for the same extent, we should always get the same multiplier + // Only perform this check if a map is provided + if (extent_to_multiplier_map) { + Val* extent = sharded_id->extent(); + auto it = extent_to_multiplier_map->find(extent); + if (it != extent_to_multiplier_map->end()) { + NVF_ERROR( + it->second == multiplier, + "Inconsistent multiplier for extent ", + extent->toString(), + ": expected ", + it->second, + " but got ", + multiplier); + } else { + (*extent_to_multiplier_map)[extent] = multiplier; + } + } else { + // NVF_ERROR(false, "Extent to multiplier map not provided"); + } unsharded_sizes.at(sharded_axis) *= multiplier; } diff --git a/csrc/multidevice/execution_utils.h b/csrc/multidevice/execution_utils.h index 2032b440f22..99323f73d10 100644 --- a/csrc/multidevice/execution_utils.h +++ b/csrc/multidevice/execution_utils.h @@ -67,6 +67,7 @@ NVF_API at::Tensor shardTensor( // ExpressionEvaluator, and so on, which is an API overhaul. std::vector unshardedSizes( const TensorView* tv, - c10::IntArrayRef sizes); + c10::IntArrayRef sizes, + std::unordered_map* extent_to_multiplier_map = nullptr); } // namespace nvfuser diff --git a/csrc/runtime/allocations.cpp b/csrc/runtime/allocations.cpp index 3c5a9c2fb46..31dce3fd246 100644 --- a/csrc/runtime/allocations.cpp +++ b/csrc/runtime/allocations.cpp @@ -886,6 +886,7 @@ std::pair, std::vector> inferShapeOfOutput( TensorShapeInfo inferTensorShapes( TensorView* tv, const ExpressionEvaluator& expr_eval) { + auto* extent_map = expr_eval.getExtentToMultiplierMap(); // Alias handling: auto alias_info = tv->fusion()->getOutputAlias(tv); if (alias_info.type != AllocationType::New) { @@ -902,7 +903,7 @@ TensorShapeInfo inferTensorShapes( return TensorShapeInfo{ tensor.sizes().vec(), tensor.strides().vec(), - isSharded(tv) ? unshardedSizes(tv, tensor.sizes().vec()) + isSharded(tv) ? unshardedSizes(tv, tensor.sizes().vec(), extent_map) : std::vector(), }; } @@ -911,7 +912,7 @@ TensorShapeInfo inferTensorShapes( return TensorShapeInfo{ tensor.sizes().vec(), tensor.strides().vec(), - isSharded(tv) ? unshardedSizes(tv, tensor.sizes().vec()) + isSharded(tv) ? unshardedSizes(tv, tensor.sizes().vec(), extent_map) : std::vector(), allocation_size_stride.first, allocation_size_stride.second}; @@ -923,7 +924,7 @@ TensorShapeInfo inferTensorShapes( return TensorShapeInfo{ allocation_size_stride.first, allocation_size_stride.second, - isSharded(tv) ? unshardedSizes(tv, allocation_size_stride.first) + isSharded(tv) ? unshardedSizes(tv, allocation_size_stride.first, extent_map) : std::vector(), }; } @@ -940,7 +941,7 @@ TensorShapeInfo inferTensorShapes( return TensorShapeInfo{ logical_meta_tensor.sizes().vec(), logical_meta_tensor.strides().vec(), - isSharded(tv) ? unshardedSizes(tv, logical_meta_tensor.sizes().vec()) + isSharded(tv) ? unshardedSizes(tv, logical_meta_tensor.sizes().vec(), extent_map) : std::vector(), allocation_size_stride.first, allocation_size_stride.second}; diff --git a/csrc/runtime/executor.cpp b/csrc/runtime/executor.cpp index baa478f095f..271790e7048 100644 --- a/csrc/runtime/executor.cpp +++ b/csrc/runtime/executor.cpp @@ -744,7 +744,7 @@ void KernelExecutor::initializeExecutorEntry( shape_info.logical_strides = arg_tensor.strides().vec(); if (isSharded(input_tv)) { shape_info.unsharded_logical_sizes = - unshardedSizes(input_tv, shape_info.logical_sizes); + unshardedSizes(input_tv, shape_info.logical_sizes, expr_eval.getExtentToMultiplierMap()); } shape_info.allocation_sizes = alloc_sizes; shape_info.allocation_strides = alloc_strides; diff --git a/csrc/tensor_metadata.cpp b/csrc/tensor_metadata.cpp index 89d83a97eea..f18b644cf2c 100644 --- a/csrc/tensor_metadata.cpp +++ b/csrc/tensor_metadata.cpp @@ -302,7 +302,7 @@ inferAllocationSizesAndStrides( const auto& alloc = tv->getMaybeAllocationDomain(); // active IDs and their shape and stride - std::vector logical_sizes = unshardedSizes(tv, tensor.sizes()); + std::vector logical_sizes = unshardedSizes(tv, tensor.sizes(), ee.getExtentToMultiplierMap()); std::unordered_map> active_ids; int64_t dim_index = 0; for (IterDomain* id : logical | TensorDomain::kNoReductions) { @@ -398,7 +398,7 @@ std::vector GetMetaData::evaluate( metadata->data = input.data_ptr(); if (isSharded(tv)) { - std::vector unsharded_sizes = unshardedSizes(tv, input.sizes()); + std::vector unsharded_sizes = unshardedSizes(tv, input.sizes(), ee.getExtentToMultiplierMap()); metadata->logical_size_data = std::move(unsharded_sizes); metadata->logical_size = c10::makeArrayRef(metadata->logical_size_data); } else {