2 changes: 1 addition & 1 deletion csrc/evaluator_common.cpp
@@ -382,7 +382,7 @@ void PrecomputedValues::bindTensorMetaData(
tensor.dim() == static_cast<int64_t>(logical_domain.size()),
"Something went wrong configuring launch. Inputs do not match.");

-  std::vector<int64_t> logical_sizes = unshardedSizes(tv, tensor.sizes());
+  std::vector<int64_t> logical_sizes = unshardedSizes(tv, tensor.sizes(), &extent_to_multiplier_map_);
adjustEvaluatorSizes(tv, logical_sizes);

for (const auto dim : arange(static_cast<int64_t>(logical_domain.size()))) {
8 changes: 8 additions & 0 deletions csrc/evaluator_common.h
@@ -181,6 +181,11 @@ class PrecomputedValues {
return has_valid_values_;
}

+  //! Get the extent to multiplier map for unshardedSizes
+  std::unordered_map<Val*, int64_t>* getExtentToMultiplierMap() {
+    return &extent_to_multiplier_map_;
+  }

//! Runs the internal value machine that will compute
//! the values allocated in the workspace.
void evaluate();
@@ -289,6 +294,9 @@ class PrecomputedValues {
//! Stores the IR nodes corresponding to each index.
std::vector<Val*> symbols_;

+  //! Extent to multiplier map for unshardedSizes - owned by this PrecomputedValues
+  std::unordered_map<Val*, int64_t> extent_to_multiplier_map_;

//! An internal log to keep track of all the bindings
//! used in each evaluation cycle. To be used for
//! consistency check.
2 changes: 1 addition & 1 deletion csrc/expr_evaluator.cpp
@@ -146,7 +146,7 @@ void ExpressionEvaluator::bindTensorDomain(
tv->toString(),
", to be bound to a tensor of equal rank.");

-  std::vector<int64_t> logical_sizes = unshardedSizes(tv, t.sizes());
+  std::vector<int64_t> logical_sizes = unshardedSizes(tv, t.sizes(), getExtentToMultiplierMap());
adjustEvaluatorSizes(tv, logical_sizes);

for (const auto& [i, id] : enumerate(logical_domain)) {
5 changes: 5 additions & 0 deletions csrc/expr_evaluator.h
@@ -79,6 +79,11 @@ class ExpressionEvaluator {
return precomputed_values_;
}

+  //! Get the extent to multiplier map for unshardedSizes; returns nullptr if no PrecomputedValues is attached
+  std::unordered_map<Val*, int64_t>* getExtentToMultiplierMap() const {
+    return precomputed_values_ ? precomputed_values_->getExtentToMultiplierMap() : nullptr;
+  }

//! Augment the evaluator with the exact root-domain map such that
//! if the extent of a root ID is known, the extents of all other
//! root IDs that are exactly mapped also get bound to the same
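Note on the two accessors added above: the map itself is owned by PrecomputedValues, and ExpressionEvaluator only forwards a raw pointer to it, returning nullptr when no PrecomputedValues is attached. A minimal, self-contained sketch of that owner/forwarder pattern (illustrative class names, not the real nvFuser types):

#include <cstdint>
#include <unordered_map>

// Owns the map and hands out a raw pointer; the pointer is only valid for
// the owner's lifetime (as with PrecomputedValues in this PR).
class OwnerSketch {
 public:
  std::unordered_map<int64_t, int64_t>* getExtentToMultiplierMap() {
    return &extent_to_multiplier_map_;
  }

 private:
  std::unordered_map<int64_t, int64_t> extent_to_multiplier_map_;
};

// Forwards the owner's pointer, or nullptr when no owner is attached
// (mirroring ExpressionEvaluator::getExtentToMultiplierMap above).
class ForwarderSketch {
 public:
  explicit ForwarderSketch(OwnerSketch* owner = nullptr) : owner_(owner) {}

  std::unordered_map<int64_t, int64_t>* getExtentToMultiplierMap() const {
    return owner_ ? owner_->getExtentToMultiplierMap() : nullptr;
  }

 private:
  OwnerSketch* owner_ = nullptr;
};

int main() {
  OwnerSketch owner;
  ForwarderSketch with_owner(&owner);
  ForwarderSketch without_owner;
  // with_owner forwards a pointer into `owner`; without_owner yields nullptr,
  // which downstream code must treat as "no consistency check".
  return (with_owner.getExtentToMultiplierMap() != nullptr &&
          without_owner.getExtentToMultiplierMap() == nullptr)
      ? 0
      : 1;
}

Because the returned pointer aliases a member of the owning object, callers should not cache it beyond the lifetime of the PrecomputedValues that produced it.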
24 changes: 23 additions & 1 deletion csrc/multidevice/execution_utils.cpp
@@ -50,7 +50,8 @@ at::Tensor shardTensor(

std::vector<int64_t> unshardedSizes(
const TensorView* tv,
-    c10::IntArrayRef sizes) {
+    c10::IntArrayRef sizes,
+    std::unordered_map<Val*, int64_t>* extent_to_multiplier_map) {
std::vector<int64_t> unsharded_sizes = sizes.vec();
for (ParallelType parallel_type : deviceAndStreamParallelTypes()) {
const DomainType domain_type = parallel_type == ParallelType::Stream
@@ -101,6 +102,27 @@ std::vector<int64_t> unshardedSizes(

NVF_THROW("Unexpected parallel type: ", parallel_type);
}();

+    // Consistency check: for the same extent, unshardedSizes must always
+    // apply the same multiplier. Record the multiplier the first time an
+    // extent is seen and verify it on every subsequent call. The check is
+    // skipped when no map is provided.
+    if (extent_to_multiplier_map != nullptr) {
+      Val* extent = sharded_id->extent();
+      auto it = extent_to_multiplier_map->find(extent);
+      if (it != extent_to_multiplier_map->end()) {
+        NVF_ERROR(
+            it->second == multiplier,
+            "Inconsistent multiplier for extent ",
+            extent->toString(),
+            ": expected ",
+            it->second,
+            " but got ",
+            multiplier);
+      } else {
+        (*extent_to_multiplier_map)[extent] = multiplier;
+      }
+    }
unsharded_sizes.at(sharded_axis) *= multiplier;
}

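Note on the consistency check added above: it is a record-or-check pattern. The first time an extent is seen its multiplier is recorded; every later occurrence of the same extent must carry the same multiplier, otherwise NVF_ERROR fires. A compilable sketch of just that pattern, using a plain integer key in place of the real extent Val* and assert in place of NVF_ERROR:

#include <cassert>
#include <cstdint>
#include <unordered_map>

// Record the multiplier for an extent key on first sight; on later calls,
// require that the same key always arrives with the same multiplier.
void recordOrCheckMultiplier(
    std::unordered_map<int64_t, int64_t>& extent_to_multiplier,
    int64_t extent_key,
    int64_t multiplier) {
  // try_emplace inserts only if the key is absent and returns the
  // (possibly pre-existing) entry either way.
  auto it = extent_to_multiplier.try_emplace(extent_key, multiplier).first;
  assert(it->second == multiplier && "inconsistent multiplier for extent");
}

int main() {
  std::unordered_map<int64_t, int64_t> extent_to_multiplier;
  recordOrCheckMultiplier(extent_to_multiplier, /*extent_key=*/7, /*multiplier=*/8);
  recordOrCheckMultiplier(extent_to_multiplier, /*extent_key=*/7, /*multiplier=*/8);  // same value: ok
  // recordOrCheckMultiplier(extent_to_multiplier, 7, 4);  // would fail the check
  return 0;
}

In the PR the key is the extent Val*, so the check spans every unshardedSizes call that shares the same map: if one call unshards an extent by one multiplier and a later call sees a different multiplier for that same extent, the error reports both values.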
3 changes: 2 additions & 1 deletion csrc/multidevice/execution_utils.h
@@ -67,6 +67,7 @@ NVF_API at::Tensor shardTensor(
// ExpressionEvaluator, and so on, which is an API overhaul.
std::vector<int64_t> unshardedSizes(
const TensorView* tv,
-    c10::IntArrayRef sizes);
+    c10::IntArrayRef sizes,
+    std::unordered_map<Val*, int64_t>* extent_to_multiplier_map = nullptr);

} // namespace nvfuser
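Note on the signature change above: defaulting the new parameter to nullptr keeps every existing caller source-compatible, while callers that hold an ExpressionEvaluator or PrecomputedValues can pass its map to opt into the check. A small, self-contained sketch of this "optional state pointer" API shape (the function below is a stand-in, not the real unshardedSizes):

#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Stand-in for unshardedSizes: scales one axis and, if a map is supplied,
// records/checks the multiplier used for that axis.
std::vector<int64_t> unshardSketch(
    std::vector<int64_t> sizes,
    int64_t sharded_axis,
    int64_t multiplier,
    std::unordered_map<int64_t, int64_t>* extent_to_multiplier_map = nullptr) {
  if (extent_to_multiplier_map != nullptr) {
    auto it =
        extent_to_multiplier_map->try_emplace(sharded_axis, multiplier).first;
    assert(it->second == multiplier && "inconsistent multiplier");
  }
  sizes.at(sharded_axis) *= multiplier;
  return sizes;
}

int main() {
  std::unordered_map<int64_t, int64_t> map;
  // Opted-in call: the multiplier for axis 0 is recorded in `map`.
  auto checked = unshardSketch({2, 3}, /*sharded_axis=*/0, /*multiplier=*/4, &map);
  // Legacy-style call: default nullptr, no check performed.
  auto unchecked = unshardSketch({2, 3}, /*sharded_axis=*/0, /*multiplier=*/4);
  assert(checked[0] == 8 && unchecked[0] == 8);
  return 0;
}

The real function keys the map by Val* rather than by axis, but the calling convention is the same: pass the evaluator's map where one is available (as the updated call sites in allocations.cpp, executor.cpp, and tensor_metadata.cpp below do), and omit it where no evaluator state exists.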
9 changes: 5 additions & 4 deletions csrc/runtime/allocations.cpp
@@ -886,6 +886,7 @@ std::pair<std::vector<int64_t>, std::vector<int64_t>> inferShapeOfOutput(
TensorShapeInfo inferTensorShapes(
TensorView* tv,
const ExpressionEvaluator& expr_eval) {
+  auto* extent_map = expr_eval.getExtentToMultiplierMap();
// Alias handling:
auto alias_info = tv->fusion()->getOutputAlias(tv);
if (alias_info.type != AllocationType::New) {
@@ -902,7 +903,7 @@ TensorShapeInfo inferTensorShapes(
return TensorShapeInfo{
tensor.sizes().vec(),
tensor.strides().vec(),
-      isSharded(tv) ? unshardedSizes(tv, tensor.sizes().vec())
+      isSharded(tv) ? unshardedSizes(tv, tensor.sizes().vec(), extent_map)
: std::vector<int64_t>(),
};
}
@@ -911,7 +912,7 @@ TensorShapeInfo inferTensorShapes(
return TensorShapeInfo{
tensor.sizes().vec(),
tensor.strides().vec(),
-      isSharded(tv) ? unshardedSizes(tv, tensor.sizes().vec())
+      isSharded(tv) ? unshardedSizes(tv, tensor.sizes().vec(), extent_map)
: std::vector<int64_t>(),
allocation_size_stride.first,
allocation_size_stride.second};
@@ -923,7 +924,7 @@ TensorShapeInfo inferTensorShapes(
return TensorShapeInfo{
allocation_size_stride.first,
allocation_size_stride.second,
-      isSharded(tv) ? unshardedSizes(tv, allocation_size_stride.first)
+      isSharded(tv) ? unshardedSizes(tv, allocation_size_stride.first, extent_map)
: std::vector<int64_t>(),
};
}
@@ -940,7 +941,7 @@ TensorShapeInfo inferTensorShapes(
return TensorShapeInfo{
logical_meta_tensor.sizes().vec(),
logical_meta_tensor.strides().vec(),
-      isSharded(tv) ? unshardedSizes(tv, logical_meta_tensor.sizes().vec())
+      isSharded(tv) ? unshardedSizes(tv, logical_meta_tensor.sizes().vec(), extent_map)
: std::vector<int64_t>(),
allocation_size_stride.first,
allocation_size_stride.second};
2 changes: 1 addition & 1 deletion csrc/runtime/executor.cpp
@@ -744,7 +744,7 @@ void KernelExecutor::initializeExecutorEntry(
shape_info.logical_strides = arg_tensor.strides().vec();
if (isSharded(input_tv)) {
shape_info.unsharded_logical_sizes =
-          unshardedSizes(input_tv, shape_info.logical_sizes);
+          unshardedSizes(input_tv, shape_info.logical_sizes, expr_eval.getExtentToMultiplierMap());
}
shape_info.allocation_sizes = alloc_sizes;
shape_info.allocation_strides = alloc_strides;
4 changes: 2 additions & 2 deletions csrc/tensor_metadata.cpp
@@ -302,7 +302,7 @@ inferAllocationSizesAndStrides(
const auto& alloc = tv->getMaybeAllocationDomain();

// active IDs and their shape and stride
-  std::vector<int64_t> logical_sizes = unshardedSizes(tv, tensor.sizes());
+  std::vector<int64_t> logical_sizes = unshardedSizes(tv, tensor.sizes(), ee.getExtentToMultiplierMap());
std::unordered_map<IterDomain*, std::pair<int64_t, int64_t>> active_ids;
int64_t dim_index = 0;
for (IterDomain* id : logical | TensorDomain::kNoReductions) {
@@ -398,7 +398,7 @@ std::vector<PolymorphicValue> GetMetaData::evaluate(
metadata->data = input.data_ptr();

if (isSharded(tv)) {
-    std::vector<int64_t> unsharded_sizes = unshardedSizes(tv, input.sizes());
+    std::vector<int64_t> unsharded_sizes = unshardedSizes(tv, input.sizes(), ee.getExtentToMultiplierMap());
metadata->logical_size_data = std::move(unsharded_sizes);
metadata->logical_size = c10::makeArrayRef(metadata->logical_size_data);
} else {