
Commit cce8875

Authored by protonu, naoyam, and greptile-apps[bot]
Extend the block quantization op to have optional swizzled block scales output (#5554)
1. Extends the runtime function.
   - Extended the block_quantize_to_nvfp4 function with optional swizzling parameters.
   - Added support for 5D allocation domain swizzling with a specific pattern: [m/128, k/4, 32(m_i), 4(m_o), 4(k)].
   - Implemented swizzled address calculation logic for an optimal memory layout, per the Blackwell documentation.
2. Relaxes checks in validation to allow for an allocation domain in the block scales output.
3. Extends codegen to pass the allocation domain extents to the runtime function.

Co-authored-by: Naoya Maruyama <[email protected]>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
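The swizzle in (1) is easiest to see with a worked example. The following is a minimal standalone sketch (not part of this commit; the helper name `swizzledOffset` is made up) that maps a logical (row, col) coordinate of the block-scale matrix to its linear offset in the 5D [m/128, k/4, 32(m_i), 4(m_o), 4(k)] allocation, mirroring the address calculation added to the runtime function:

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical helper mirroring the swizzle in this commit:
// 2D logical [m, k] -> 5D allocation [m/128, k/4, 32(m_i), 4(m_o), 4(k)].
// Assumes m is a multiple of 128 and k is a multiple of 4.
int64_t swizzledOffset(int64_t row, int64_t col, int64_t k) {
  // Decompose the logical coordinate according to the three splits.
  const int64_t k_outer = col / 4; // k/4
  const int64_t k_inner = col % 4; // 4(k)
  const int64_t m_tile = row / 128; // m/128
  const int64_t m_rest = row % 128; // 128 -> 4(m_o) x 32(m_i)
  const int64_t m_outer = m_rest / 32; // 4(m_o)
  const int64_t m_inner = m_rest % 32; // 32(m_i)

  // Row-major strides of the reordered 5D allocation [m/128, k/4, 32, 4, 4].
  const int64_t stride_k_inner = 1;
  const int64_t stride_m_outer = 4;
  const int64_t stride_m_inner = 16;
  const int64_t stride_k_outer = 512;
  const int64_t stride_m_tile = 512 * (k / 4);

  return m_tile * stride_m_tile + k_outer * stride_k_outer +
      m_inner * stride_m_inner + m_outer * stride_m_outer +
      k_inner * stride_k_inner;
}

int main() {
  // Element (row=37, col=5) of a 256x8 block-scale matrix:
  // m_tile=0, m_outer=1, m_inner=5, k_outer=1, k_inner=1.
  // Expected offset: 0*1024 + 1*512 + 5*16 + 1*4 + 1*1 = 597.
  std::printf("%lld\n", static_cast<long long>(swizzledOffset(37, 5, 8)));
  return 0;
}
```

Note how each 128-row by 4-scale-column tile occupies 512 contiguous elements, matching the 128x4 scale-factor atom described in the Blackwell/CUTLASS layout documentation linked below.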
Parent commit: 722fc4f

File tree: 4 files changed, +399 −74 lines


csrc/codegen.cpp

Lines changed: 57 additions & 44 deletions
```diff
@@ -1807,70 +1807,83 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
   }
 
   // Special handling of BlockQuantizationOp to call the runtime function.
-  // TODO: add support for global scaling factor
   void handle(const BlockQuantizationOp* bqop) final {
     // This operator is plumbed down to a runtime function call.
     // One of the assumptions is that the device runtime expects
     // n consecutive inputs per thread, where n can be 2 or 4 for Float, and
     // 2, 4, or 8 for Half. We achieve this by having the quantized output tv
     // scheduled to have the inner dimension grouped by 2/4/8.
     auto output = bqop->quantizedOutput()->as<kir::TensorIndex>()->view();
-    int64_t group_size = 1;
 
-    // Get the loop domain of the TensorView output and check for group
-    // parallel types. This assumes that both parallel types aren't present.
+    // Extract group size from the loop domain
+    int64_t group_size = 1;
     const auto& loop_domain = output->getLoopDomain();
-    for (auto* domain : loop_domain) {
-      auto parallel_type = domain->getParallelType();
-      if (parallel_type == ParallelType::Group) {
-        if (domain->extent()->isConstInt()) {
-          group_size = domain->extent()->evaluate().as<int64_t>();
-        }
+    for (const auto* domain : loop_domain) {
+      if (domain->getParallelType() == ParallelType::Group &&
+          domain->extent()->isConstInt()) {
+        group_size = domain->extent()->evaluate().as<int64_t>();
+        break;
       }
     }
 
-    auto input_dtype =
-        bqop->in()->as<kir::TensorIndex>()->view()->getDataType();
-
-    if (input_dtype == DataType::BFloat16 || input_dtype == DataType::Half) {
-      NVF_ERROR(
-          group_size == 8 || group_size == 4 || group_size == 2,
-          "Group size should be 2, 4 or 8 for "
-          "BlockQuantizationOp: ",
-          bqop->toString());
-
-    } else {
-      NVF_ERROR(
-          group_size == 4 || group_size == 2,
-          "Group size should be 2 or 4 for "
-          "BlockQuantizationOp: ",
-          bqop->toString());
-    }
+    // Validate group size based on input data type
+    const auto input_dtype =
+        bqop->in()->as<kir::TensorIndex>()->view()->getDataType().value();
+    const bool is_half_precision =
+        (input_dtype == DataType::BFloat16 || input_dtype == DataType::Half);
+    const bool is_valid_group_size = is_half_precision
+        ? (group_size == 2 || group_size == 4 || group_size == 8)
+        : (group_size == 2 || group_size == 4);
 
+    NVF_ERROR(
+        is_valid_group_size,
+        "Group size should be ",
+        is_half_precision ? "2, 4 or 8" : "2 or 4",
+        " for BlockQuantizationOp with input type ",
+        input_dtype,
+        ". Found: ",
+        group_size,
+        ". Expr: ",
+        bqop->toString());
+
+    // Build template arguments
     ArgumentBuilder template_args;
-    template_args.arg(
-        bqop->hasGlobalScale() ? true : false); // HAS_GLOBAL_SCALE
+    template_args.arg(bqop->hasGlobalScale()); // HAS_GLOBAL_SCALE
     template_args.arg(group_size); // ITEMS_PER_THREAD
 
-    // Function arguments
+    // Build function arguments
     ArgumentBuilder func_args;
-
-    // First argument: input data array
-    // Second argument: quantized output
-    // Third argument: block scale output
-    func_args.arg(genInline(bqop->input(0)->as<kir::TensorIndex>()->view()));
-    func_args.arg(genInline(output));
-    func_args.arg(
-        genInline(bqop->blockScales()->as<kir::TensorIndex>()->view()));
-
-    // Fourth argument: This holds the linearized index that will be used to
-    // write out the block scaling factors in the runtime function.
-    func_args.arg(genInline(bqop->attributeVal(0)));
-
-    // Fifth argument: global scale (if any)
+    func_args.arg(genInline(
+        bqop->input(0)->as<kir::TensorIndex>()->view())); // input data
+    func_args.arg(genInline(output)); // quantized output
+    func_args.arg(genInline(
+        bqop->blockScales()->as<kir::TensorIndex>()->view())); // block scales
+    func_args.arg(genInline(
+        bqop->attributeVal(0))); // linearized index for runtime function
     func_args.arg(
         bqop->hasGlobalScale() ? genInline(bqop->globalScale()) : "{}");
 
+    // Add swizzled allocation domain parameters if needed
+    auto block_scales_tv = bqop->blockScales()->as<kir::TensorIndex>()->view();
+    if (block_scales_tv->hasAllocation()) {
+      auto logical_domain =
+          TensorDomain::noReductions(block_scales_tv->getLogicalDomain());
+      auto allocation_domain =
+          TensorDomain::noReductions(block_scales_tv->getAllocationDomain());
+
+      // Swizzled layout: 2D logical -> 5D allocation
+      if (logical_domain.size() == 2 && allocation_domain.size() == 5) {
+        // Add logical domain extent of the inner dimension
+        func_args.arg(genInline(logical_domain[1]->extent()));
+
+        // Add all allocation domain extents
+        for (const auto* alloc_id : allocation_domain) {
+          func_args.arg(genInline(alloc_id->extent()));
+        }
+      }
+    }
+
+    // Generate the function call
     indent() << genCall("bq::block_quantize_to_nvfp4", template_args, func_args)
              << ";\n";
   }
```
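For orientation, the emitted call takes roughly the following shape when the block-scales output carries the 5D allocation domain (a hand-written sketch, not captured codegen output; the tensor names and extent expressions are hypothetical):

```cpp
// HAS_GLOBAL_SCALE = false, ITEMS_PER_THREAD = 4. The last six arguments
// (the logical inner extent followed by the five allocation-domain extents)
// are the ones added by this commit; they are omitted when the block-scales
// tensor has no allocation domain.
bq::block_quantize_to_nvfp4<false, 4>(
    T0_data, T1, T2, i_linear, {}, k, m / 128LL, k / 4LL, 32LL, 4LL, 4LL);
```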

csrc/device_lower/validation.cpp

Lines changed: 147 additions & 6 deletions
```diff
@@ -274,6 +274,147 @@ bool isInnermost(IterDomain* base_id, IterDomain* maybe_innermost_id) {
   return !frontier.empty() && frontier.back() == maybe_innermost_id;
 }
 
+// Validate the swizzling pattern:
+// We support a very restricted pattern from 2D logical to 5D allocation.
+// Expected pattern:
+//   m, k -> m, k/4, 4 (split k by 4)
+//   m, k/4, 4 -> m/128, 128, k/4, 4 (split m by 128)
+//   m/128, 128, k/4, 4 -> m/128, 4(m_o), 32(m_i), k/4, 4 (split 128 by 32)
+// Then reorder to: m/128, k/4, 32(m_i), 4(m_o), 4(k)
+void isValidBlockScaleSwizzle(TensorView* block_scale) {
+  auto logical_domain =
+      TensorDomain::noReductions(block_scale->getLogicalDomain());
+  auto allocation_domain =
+      TensorDomain::noReductions(block_scale->getAllocationDomain());
+
+  // Check that the logical domain is 2D and the allocation domain is 5D
+  NVF_ERROR(
+      logical_domain.size() == 2 && allocation_domain.size() == 5,
+      "Block scale swizzle must have 2D logical domain and 5D allocation "
+      "domain. Found: ",
+      logical_domain.size(),
+      "D logical and ",
+      allocation_domain.size(),
+      "D allocation for TensorView: ",
+      block_scale->toString());
+
+  // Keep count of splits
+  int num_splits = 0;
+
+  // Keeps track of the split M -> M/128, 128
+  Split* middle_split = nullptr;
+
+  // A lambda to check the transforms from logical to allocation domain.
+  // Each transform must be a split, and there can be only 3 splits.
+  auto check_transform = [block_scale,
+                          &logical_domain,
+                          &num_splits,
+                          &middle_split](Expr* expr) {
+    if (auto split_expr = dynamic_cast<Split*>(expr)) {
+      // Can have a max of 3 splits - checked later
+      num_splits++;
+
+      // If the split's input is logical_domain.back(), the inner split
+      // output should have an extent of 4.
+      // Check K -> K/4, 4
+      if (split_expr->in() == logical_domain.back()) {
+        NVF_ERROR(
+            split_expr->inner()->extent()->isConstInt() &&
+                split_expr->inner()->extent()->evaluate().as<int64_t>() == 4,
+            "The innermost split in block scale swizzle must have an extent "
+            "of 4. Found extent: ",
+            split_expr->inner()->extent()->toString(),
+            " in expr: ",
+            expr->toString(),
+            " for TensorView: ",
+            block_scale->toString());
+      } else if (split_expr->in() == logical_domain.front()) {
+        // Check M -> M/128, 128
+        NVF_ERROR(
+            split_expr->inner()->extent()->isConstInt() &&
+                split_expr->inner()->extent()->evaluate().as<int64_t>() ==
+                    128,
+            "The outermost split in block scale swizzle must have an extent "
+            "of 128. Found extent: ",
+            split_expr->inner()->extent()->toString(),
+            " in expr: ",
+            expr->toString(),
+            " for TensorView: ",
+            block_scale->toString());
+
+        // Cache the M -> M/128, 128 split
+        middle_split = split_expr;
+      } else {
+        // Check that the input to this split is the inner output of
+        // middle_split, as we should have 128 -> 4, 32
+        NVF_ERROR(
+            middle_split && split_expr->in() == middle_split->inner(),
+            "The third split in block scale swizzle must split the inner "
+            "output (extent 128) of the second split. Expected input to be "
+            "the inner output of the M/128, 128 split. Found expr: ",
+            split_expr->toString(),
+            " for TensorView: ",
+            block_scale->toString());
+
+        NVF_ERROR(
+            split_expr->inner()->extent()->isConstInt() &&
+                split_expr->inner()->extent()->evaluate().as<int64_t>() == 32,
+            "The third split in block scale swizzle (128 -> 4, 32) must have "
+            "an inner extent of 32. Found extent: ",
+            split_expr->inner()->extent()->toString(),
+            " in expr: ",
+            split_expr->toString(),
+            " for TensorView: ",
+            block_scale->toString());
+      }
+    } else {
+      NVF_THROW(
+          "Logical to allocation domain transforms for block scale swizzle "
+          "can only contain split operations");
+    }
+  };
+
+  // Get all exprs between logical and allocation domain
+  auto transform_exprs = DependencyCheck::getAllExprsBetween(
+      {logical_domain.begin(), logical_domain.end()},
+      {allocation_domain.begin(), allocation_domain.end()});
+
+  std::vector<IterDomain*> ids_to_transform = logical_domain;
+
+  // Transform the logical domain to the allocation domain
+  // without the permutation.
+  scheduler_utils::applyTransforms(
+      ids_to_transform, transform_exprs, check_transform);
+
+  // Check that there are exactly 3 splits
+  NVF_ERROR_EQ(
+      num_splits,
+      3,
+      "Block scale swizzle must have exactly 3 splits. Found ",
+      num_splits,
+      " splits in TensorView: ",
+      block_scale->toString());
+
+  // Get the permutation.
+  auto permutation =
+      ir_utils::computePermutation(ids_to_transform, allocation_domain);
+
+  // m/128, 4(m_o), 32(m_i), k/4, 4(k)
+  //   -> m/128, k/4, 32(m_i), 4(m_o), 4(k)
+  // Check that the permutation has a value and that it is 0, 3, 2, 1, 4
+  NVF_ERROR(
+      permutation.has_value() &&
+          permutation.value() == std::vector<int64_t>({0, 3, 2, 1, 4}),
+      "Block scale swizzle permutation is invalid for TensorView: ",
+      block_scale->toString());
+}
+
 // Expr-specific validation
 //
 // TODO: Move individual validations to here, e.g.,
@@ -515,15 +656,15 @@ class ExprValidator : public OptOutDispatch {
         !quantized_output->hasAllocation(),
         "Quantized output must not have an allocation domain.");
 
-    // TODO: Relax these for swizzled block scaling factor outputs
-    // When scaling will be swizzled we will need to allow these checks
-    // to be relaxed, but we will need to ensure that the swizzling
+    // When the output scales are swizzled, these checks are relaxed,
+    // but we need to ensure that the swizzling
     // allocation allowed is a fixed pattern:
     // 2D logical and 5D allocation domain.
     // https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#scale-factor-layouts
-    NVF_ERROR(
-        !block_scaling_factor->hasAllocation(),
-        "Block scaling factor must not have an allocation domain.");
+    if (block_scaling_factor->hasAllocation()) {
+      isValidBlockScaleSwizzle(block_scaling_factor);
+    }
+
     NVF_ERROR(
         std::all_of(
             block_scaling_factor->getContiguity().begin(),
```
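For reference, a schedule that satisfies isValidBlockScaleSwizzle would look roughly like the following sketch (hypothetical scheduling code, assuming the usual TensorView split/reorder/setAllocationDomain API and a block-scales output with 2D logical domain [m, k]):

```cpp
// block_scales starts from its 2D logical domain [m, k].
block_scales->split(1, 4); // [m, k/4, 4(k)]
block_scales->split(0, 128); // [m/128, 128, k/4, 4(k)]
block_scales->split(1, 32); // [m/128, 4(m_o), 32(m_i), k/4, 4(k)]
// Swap axes 1 and 3 to reach [m/128, k/4, 32(m_i), 4(m_o), 4(k)]. Relative
// to the split-only order, this is exactly the {0, 3, 2, 1, 4} permutation
// the validator checks for.
block_scales->reorder({{1, 3}, {3, 1}});
block_scales->setAllocationDomain(block_scales->getLoopDomain(), true);
```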

runtime/block_quantization_kernels.cu

Lines changed: 35 additions & 1 deletion
```diff
@@ -52,7 +52,13 @@ __device__ void block_quantize_to_nvfp4(
     Array<__e2m1, ITEMS_PER_THREAD, ALIGNMENT_2>& output,
     Tensor<__e4m3, BLOCK_SCALE_DIM, BLOCK_SCALE_ALLOC>& block_scales,
     nvfuser_index_t logical_index,
-    Tensor<float, 0, 0> global_scale) {
+    Tensor<float, 0, 0> global_scale,
+    int64_t fp8_scaling_factors_inner_dim = -1,
+    int64_t alloc_dim0 = -1,
+    int64_t alloc_dim1 = -1,
+    int64_t alloc_dim2 = -1,
+    int64_t alloc_dim3 = -1,
+    int64_t alloc_dim4 = -1) {
   constexpr bool is_half_or_bfloat =
       std::is_same<T, __bfloat>::value || std::is_same<T, __half>::value;
   constexpr bool is_float = std::is_same<T, float>::value;
@@ -124,6 +130,34 @@
   // Only one block scaling factor is written out per 16 (assumed block size)
   // elements.
   int offset = logical_index / 16;
+
+  if (fp8_scaling_factors_inner_dim > 0) {
+    auto stride_4 = 1;
+    auto stride_3 = stride_4 * alloc_dim4;
+    auto stride_2 = stride_3 * alloc_dim3;
+    auto stride_1 = stride_2 * alloc_dim2;
+    auto stride_0 = stride_1 * alloc_dim1;
+
+    auto logical_inner = offset % fp8_scaling_factors_inner_dim;
+    auto logical_outer = offset / fp8_scaling_factors_inner_dim;
+
+    // The allocation domain swizzle logic is:
+    //   m, k -> m, k/4, 4
+    //   m, k/4, 4 -> m/128, 128, k/4, 4 ->
+    //   m/128, 4(m), 32, k/4, 4(k) ->
+    //   m/128, k/4, 32, 4(m), 4(k)
+
+    auto pos_4 = logical_inner % 4;
+    auto pos_1 = logical_inner / 4;
+    auto pos_t = logical_outer % 128;
+    auto pos_0 = logical_outer / 128;
+    auto pos_3 = pos_t / 32;
+    auto pos_2 = pos_t % 32;
+
+    offset = pos_4 * stride_4 + pos_3 * stride_3 + pos_2 * stride_2 +
+        pos_1 * stride_1 + pos_0 * stride_0;
+  }
+
   if (threadIdx.x % THREADS_PER_SCALING_FACTOR == 0) {
     block_scales[offset] = clamped_max_fp8;
   }
```
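The stride arithmetic above is easy to get subtly wrong, so a useful host-side sanity check (again a sketch, not part of the commit) is to replay the kernel's offset computation for every logical scale offset of a small tensor and assert that the mapping is a bijection, i.e. that every swizzled slot is hit exactly once:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // One 128-row tile of scales with k = 16 scale columns.
  const int64_t m = 128, k = 16;
  const int64_t alloc_dim1 = k / 4, alloc_dim2 = 32, alloc_dim3 = 4,
                alloc_dim4 = 4;

  // Strides, computed exactly as in block_quantize_to_nvfp4.
  const int64_t stride_4 = 1;
  const int64_t stride_3 = stride_4 * alloc_dim4;
  const int64_t stride_2 = stride_3 * alloc_dim3;
  const int64_t stride_1 = stride_2 * alloc_dim2;
  const int64_t stride_0 = stride_1 * alloc_dim1;

  std::vector<int> hits(m * k, 0);
  for (int64_t offset = 0; offset < m * k; ++offset) {
    // Replay the kernel's decomposition of the linear logical offset.
    const int64_t logical_inner = offset % k;
    const int64_t logical_outer = offset / k;
    const int64_t pos_4 = logical_inner % 4;
    const int64_t pos_1 = logical_inner / 4;
    const int64_t pos_t = logical_outer % 128;
    const int64_t pos_0 = logical_outer / 128;
    const int64_t pos_3 = pos_t / 32;
    const int64_t pos_2 = pos_t % 32;
    const int64_t swizzled = pos_4 * stride_4 + pos_3 * stride_3 +
        pos_2 * stride_2 + pos_1 * stride_1 + pos_0 * stride_0;
    assert(swizzled >= 0 && swizzled < m * k);
    ++hits[swizzled];
  }
  for (int h : hits) {
    assert(h == 1); // bijection: each swizzled address is hit exactly once
  }
  return 0;
}
```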
