Skip to content

Commit 1f78534

Browse files
Jacenty-And-Intelermilindwalekar
authored andcommitted
Enable quantization for FP4 type
1 parent b63a41c commit 1f78534

File tree

7 files changed

+369
-288
lines changed

7 files changed

+369
-288
lines changed

mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h

Lines changed: 148 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,12 @@ class QuantizedType : public Type {
9797
return -getDefaultMaximumForF8E5M2();
9898
}
9999

100+
static constexpr int64_t getDefaultMaximumForF4E2M1FN() { return 6; }
101+
102+
static constexpr int64_t getDefaultMinimumForF4E2M1FN() {
103+
return -getDefaultMaximumForF4E2M1FN();
104+
}
105+
100106
/// Gets the original expressed type that this quantized type approximates.
101107
/// Note that this presumes that the quantized type was always derived from
102108
/// a floating point type, which in the broadest definition, is not true (i.e.
@@ -267,7 +273,7 @@ class AnyQuantizedType
267273
/// Per-layer, optional parameters omitted:
268274
/// !quant<uniform[StorageType]{Scale}>
269275
///
270-
/// StorageType: 'i'|'u' NumBits
276+
/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8'
271277
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
272278
/// Scale: A legal double value
273279
/// ZeroPoint: An integer value
@@ -327,7 +333,7 @@ class UniformQuantizedType
327333
/// Per-axis, optional parameters omitted:
328334
/// !quant<uniform[StorageType]{Scale}>
329335
///
330-
/// StorageType: 'i'|'u' NumBits
336+
/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8'
331337
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
332338
/// QuantizedDim: An integer value
333339
/// QuantParams: (Scale ':' ZeroPoint)+
@@ -401,149 +407,6 @@ class UniformQuantizedPerAxisType
401407
}
402408
};
403409

404-
/// QuantileQuantizedType derives from UniformQuantizedType and adds to it a
405-
/// look up table array of quantile values. The type of the data in the look up
406-
/// table is determined by the quantileType member: supported quantileType types
407-
/// are integer/unsigned/hf8/bf8/f16/bf16/f32/f64.
408-
///
409-
/// Syntax synopsis:
410-
/// Per-layer, all parameters expressed:
411-
/// !quant<quantile[StorageType:QuantileType:ExpressedType]{Quantiles}:{Scale:ZeroPoint}>
412-
/// Per-layer, optional parameters omitted:
413-
/// !quant<quantile[StorageType:QuantileType]{Quantiles}:{Scale}>
414-
///
415-
/// StorageType: 'i'|'u' NumBits
416-
/// QuantileType: 'i'|'u' NumBits, 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64'
417-
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
418-
/// Quantiles: Quantile+
419-
/// Quantile: A legal double value
420-
/// Scale: A legal double value
421-
/// ZeroPoint: An integer value
422-
class QuantileQuantizedType
423-
: public Type::TypeBase<QuantileQuantizedType, UniformQuantizedType,
424-
detail::QuantileQuantizedTypeStorage> {
425-
public:
426-
using Base::Base;
427-
using Base::getChecked;
428-
429-
static constexpr StringLiteral name = "quant.quantile";
430-
431-
/// Gets an instance of the type with all parameters specified but not
432-
/// checked.
433-
static QuantileQuantizedType get(unsigned flags, Type storageType,
434-
Type quantileType, Type expressedType,
435-
ArrayRef<double> quantiles, double scale,
436-
int64_t zeroPoint, int64_t storageTypeMin,
437-
int64_t storageTypeMax);
438-
439-
static QuantileQuantizedType
440-
getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
441-
Type storageType, Type quantileType, Type expressedType,
442-
ArrayRef<double> quantiles, double scale, int64_t zeroPoint,
443-
int64_t storageTypeMin, int64_t storageTypeMax);
444-
445-
/// Verifies construction invariants and issues errors/warnings.
446-
static LogicalResult
447-
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
448-
Type storageType, Type quantileType, Type expressedType,
449-
ArrayRef<double> quantiles, double scale, int64_t zeroPoint,
450-
int64_t storageTypeMin, int64_t storageTypeMax);
451-
452-
static bool classof(mlir::Type type);
453-
454-
/// Gets the quantileType
455-
Type getQuantileType() const;
456-
457-
/// Gets the quantileType bit width
458-
unsigned getQuantileTypeIntegralWidth() const;
459-
460-
/// Gets the quantile values
461-
ArrayRef<double> getQuantiles() const;
462-
463-
// Fixed point values are real numbers divided by a scale.
464-
// Currently, only signed storage types are treated as fixed point.
465-
// A fixed point value can be obtained from an affine value by subtracting
466-
// the zeroPoint.
467-
// In the future, this may be explicit versus implied by type and zeroPoint.
468-
bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; }
469-
};
470-
471-
/// Represents per-axis QuantileQuantizedType (also known as per-channel
472-
/// quantization). The type of the data in the look up table is determined by
473-
/// the quantileType member: supported quantileType types are
474-
/// integer/unsigned/hf8/bf8/f16/bf16/f32/f64.
475-
///
476-
/// Syntax synopsis:
477-
/// Per-axis, all parameters expressed:
478-
/// !quant<quantile[StorageType:QuantileType:ExpressedType:QuantizedDim]{Quantiles}:{QuantParams}>
479-
/// Per-axis, optional parameters omitted:
480-
/// !quant<quantile[StorageType:QuantileType]{Quantiles}:{Scale}>
481-
///
482-
/// StorageType: 'i'|'u' NumBits
483-
/// QuantileType: 'i'|'u' NumBits, 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64'
484-
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
485-
/// QuantizedDim: An integer value
486-
/// Quantiles: Quantile+
487-
/// Quantile: A legal double value
488-
/// QuantParams: (Scale ':' ZeroPoint)+
489-
/// Scale: A legal double value
490-
/// ZeroPoint: An integer value
491-
class QuantileQuantizedPerAxisType
492-
: public Type::TypeBase<QuantileQuantizedPerAxisType,
493-
UniformQuantizedPerAxisType,
494-
detail::QuantileQuantizedPerAxisTypeStorage> {
495-
public:
496-
using Base::Base;
497-
using Base::getChecked;
498-
499-
static constexpr StringLiteral name = "quant.quantile_per_axis";
500-
501-
/// Gets an instance of the type with all parameters specified but not
502-
/// checked.
503-
static QuantileQuantizedPerAxisType
504-
get(unsigned flags, Type storageType, Type quantileType, Type expressedType,
505-
ArrayRef<double> quantiles, ArrayRef<double> scales,
506-
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
507-
int64_t storageTypeMin, int64_t storageTypeMax);
508-
509-
/// Gets an instance of the type with all specified parameters checked.
510-
/// Returns a nullptr convertible type on failure.
511-
static QuantileQuantizedPerAxisType
512-
getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
513-
Type storageType, Type quantileType, Type expressedType,
514-
ArrayRef<double> quantiles, ArrayRef<double> scales,
515-
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
516-
int64_t storageTypeMin, int64_t storageTypeMax);
517-
518-
/// Verifies construction invariants and issues errors/warnings.
519-
static LogicalResult
520-
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
521-
Type storageType, Type quantileType, Type expressedType,
522-
ArrayRef<double> quantiles, ArrayRef<double> scales,
523-
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
524-
int64_t storageTypeMin, int64_t storageTypeMax);
525-
526-
static bool classof(mlir::Type type);
527-
528-
/// Gets the quantileType
529-
Type getQuantileType() const;
530-
531-
/// Gets the quantileType bit width
532-
unsigned getQuantileTypeIntegralWidth() const;
533-
534-
/// Gets the quantile values
535-
ArrayRef<double> getQuantiles() const;
536-
537-
/// Fixed point values are real numbers divided by a scale.
538-
/// Currently, only signed storage types are treated as fixed point.
539-
/// A fixed point value can be obtained from an affine value by subtracting
540-
/// the zeroPoint.
541-
/// In the future, this may be explicit versus implied by type and zeroPoint.
542-
bool isFixedPoint() const {
543-
return isSigned() && !llvm::is_contained(getZeroPoints(), 0);
544-
}
545-
};
546-
547410
/// Represents sub-channel (also known as blockwise quantization).
548411
///
549412
/// Syntax synopsis:
@@ -557,7 +420,7 @@ class QuantileQuantizedPerAxisType
557420
/// ScaleZeroList ::= ScaleZero (',' ScaleZero)*
558421
/// ScaleZero ::= Scale (':' ZeroPoint)?
559422
///
560-
/// StorageType: 'i'|'u' NumBits
423+
/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8'
561424
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
562425
/// AxisSpec: An integer value
563426
/// BlockSizeSpec: An integer value
@@ -674,6 +537,145 @@ class UniformQuantizedSubChannelType
674537
const SmallVector<std::pair<int32_t, int64_t>> getBlockSizeInfo() const;
675538
};
676539

540+
/// QuantileQuantizedType derives from UniformQuantizedType and adds to it a
541+
/// look up table array of quantile values. The type of the data in the look up
542+
/// table is determined by the quantileType member: supported quantileType types
543+
/// are integer/unsigned/f4/hf8/bf8/f16/bf16/f32/f64.
544+
///
545+
/// Syntax synopsis:
546+
/// Per-layer, all parameters expressed:
547+
/// !quant<quantile[StorageType:QuantileType:ExpressedType]{Quantiles}:{Scale:ZeroPoint}>
548+
/// Per-layer, optional parameters omitted:
549+
/// !quant<quantile[StorageType:QuantileType]{Quantiles}:{Scale}>
550+
///
551+
/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8'
552+
/// QuantileType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8', 'f16', 'bf16', 'f32',
553+
/// 'f64' ExpressedType: 'f16', 'f32', 'bf16', 'f64' Quantiles: Quantile+
554+
/// Quantile: A legal double value
555+
/// Scale: A legal double value
556+
/// ZeroPoint: An integer value
557+
class QuantileQuantizedType
558+
: public Type::TypeBase<QuantileQuantizedType, UniformQuantizedType,
559+
detail::QuantileQuantizedTypeStorage> {
560+
public:
561+
using Base::Base;
562+
using Base::getChecked;
563+
564+
static constexpr StringLiteral name = "quant.quantile";
565+
566+
/// Gets an instance of the type with all parameters specified but not
567+
/// checked.
568+
static QuantileQuantizedType get(unsigned flags, Type storageType,
569+
Type quantileType, Type expressedType,
570+
ArrayRef<double> quantiles, double scale,
571+
int64_t zeroPoint, int64_t storageTypeMin,
572+
int64_t storageTypeMax);
573+
574+
static QuantileQuantizedType
575+
getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
576+
Type storageType, Type quantileType, Type expressedType,
577+
ArrayRef<double> quantiles, double scale, int64_t zeroPoint,
578+
int64_t storageTypeMin, int64_t storageTypeMax);
579+
580+
/// Verifies construction invariants and issues errors/warnings.
581+
static LogicalResult
582+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
583+
Type storageType, Type quantileType, Type expressedType,
584+
ArrayRef<double> quantiles, double scale, int64_t zeroPoint,
585+
int64_t storageTypeMin, int64_t storageTypeMax);
586+
587+
static bool classof(mlir::Type type);
588+
589+
/// Gets the quantileType
590+
Type getQuantileType() const;
591+
592+
/// Gets the quantileType bit width
593+
unsigned getQuantileTypeIntegralWidth() const;
594+
595+
/// Gets the quantile values
596+
ArrayRef<double> getQuantiles() const;
597+
598+
// Fixed point values are real numbers divided by a scale.
599+
// Currently, only signed storage types are treated as fixed point.
600+
// A fixed point value can be obtained from an affine value by subtracting
601+
// the zeroPoint.
602+
// In the future, this may be explicit versus implied by type and zeroPoint.
603+
bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; }
604+
};
605+
606+
/// Represents per-axis QuantileQuantizedType (also known as per-channel
607+
/// quantization). The type of the data in the look up table is determined by
608+
/// the quantileType member: supported quantileType types are
609+
/// integer/unsigned/f4/hf8/bf8/f16/bf16/f32/f64.
610+
///
611+
/// Syntax synopsis:
612+
/// Per-axis, all parameters expressed:
613+
/// !quant<quantile[StorageType:QuantileType:ExpressedType:QuantizedDim]{Quantiles}:{QuantParams}>
614+
/// Per-axis, optional parameters omitted:
615+
/// !quant<quantile[StorageType:QuantileType]{Quantiles}:{Scale}>
616+
///
617+
/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8'
618+
/// QuantileType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8', 'f16', 'bf16', 'f32',
619+
/// 'f64' ExpressedType: 'f16', 'f32', 'bf16', 'f64' QuantizedDim: An integer
620+
/// value Quantiles: Quantile+ Quantile: A legal double value QuantParams:
621+
/// (Scale ':' ZeroPoint)+ Scale: A legal double value ZeroPoint: An integer
622+
/// value
623+
class QuantileQuantizedPerAxisType
624+
: public Type::TypeBase<QuantileQuantizedPerAxisType,
625+
UniformQuantizedPerAxisType,
626+
detail::QuantileQuantizedPerAxisTypeStorage> {
627+
public:
628+
using Base::Base;
629+
using Base::getChecked;
630+
631+
static constexpr StringLiteral name = "quant.quantile_per_axis";
632+
633+
/// Gets an instance of the type with all parameters specified but not
634+
/// checked.
635+
static QuantileQuantizedPerAxisType
636+
get(unsigned flags, Type storageType, Type quantileType, Type expressedType,
637+
ArrayRef<double> quantiles, ArrayRef<double> scales,
638+
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
639+
int64_t storageTypeMin, int64_t storageTypeMax);
640+
641+
/// Gets an instance of the type with all specified parameters checked.
642+
/// Returns a nullptr convertible type on failure.
643+
static QuantileQuantizedPerAxisType
644+
getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
645+
Type storageType, Type quantileType, Type expressedType,
646+
ArrayRef<double> quantiles, ArrayRef<double> scales,
647+
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
648+
int64_t storageTypeMin, int64_t storageTypeMax);
649+
650+
/// Verifies construction invariants and issues errors/warnings.
651+
static LogicalResult
652+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
653+
Type storageType, Type quantileType, Type expressedType,
654+
ArrayRef<double> quantiles, ArrayRef<double> scales,
655+
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
656+
int64_t storageTypeMin, int64_t storageTypeMax);
657+
658+
static bool classof(mlir::Type type);
659+
660+
/// Gets the quantileType
661+
Type getQuantileType() const;
662+
663+
/// Gets the quantileType bit width
664+
unsigned getQuantileTypeIntegralWidth() const;
665+
666+
/// Gets the quantile values
667+
ArrayRef<double> getQuantiles() const;
668+
669+
/// Fixed point values are real numbers divided by a scale.
670+
/// Currently, only signed storage types are treated as fixed point.
671+
/// A fixed point value can be obtained from an affine value by subtracting
672+
/// the zeroPoint.
673+
/// In the future, this may be explicit versus implied by type and zeroPoint.
674+
bool isFixedPoint() const {
675+
return isSigned() && !llvm::is_contained(getZeroPoints(), 0);
676+
}
677+
};
678+
677679
/// A quantized type that infers its range from given min/max values.
678680
///
679681
/// Typical syntax:

0 commit comments

Comments
 (0)