Skip to content

Commit 225ba18

Browse files
committed
Fix Attention addLayer; make CMake work with TRT 10.14
1 parent 57c5c8d commit 225ba18

File tree

12 files changed: +143 -17 lines changed

mlir-tensorrt/CMakePresets.json

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,20 @@
100100
"MLIR_TRT_ENABLE_NCCL": "OFF",
101101
"MLIR_TRT_DOWNLOAD_TENSORRT_VERSION": "$env{DOWNLOAD_TENSORRT_VERSION}"
102102
}
103+
},
104+
{
105+
"name": "python-wheel-build",
106+
"displayName": "Configuration for building the compiler/runtime Python package wheels",
107+
"generator": "Ninja",
108+
"binaryDir": "build",
109+
"inherits": "ninja-llvm",
110+
"cacheVariables": {
111+
"CMAKE_BUILD_TYPE": "Release",
112+
"LLVM_ENABLE_ASSERTIONS": "OFF",
113+
"CMAKE_PLATFORM_NO_VERSIONED_SONAME": "ON",
114+
"MLIR_TRT_ENABLE_NCCL": "OFF",
115+
"MLIR_TRT_DOWNLOAD_TENSORRT_VERSION": "$env{DOWNLOAD_TENSORRT_VERSION}"
116+
}
103117
}
104118
]
105119
}

mlir-tensorrt/build_tools/cmake/MTRTDependencies.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ macro(configure_tensorrt_python_plugin_header)
5757
find_file(
5858
trt_python_plugin_header
5959
NAMES NvInferPythonPlugin.h plugin.h
60-
HINTS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl
61-
PATHS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl
60+
HINTS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/include/impl
61+
PATHS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/include/impl
6262
REQUIRED
6363
NO_CMAKE_PATH NO_DEFAULT_PATH
6464
NO_CACHE

mlir-tensorrt/build_tools/cmake/TensorRTDownloadURL.cmake

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_
8080
set(ARG_VERSION "10.12.0.36")
8181
endif()
8282

83+
if(ARG_VERSION VERSION_EQUAL "10.14")
84+
set(ARG_VERSION "10.14.1.48")
85+
endif()
86+
8387
set(downloadable_versions
8488
"8.6.1.6"
8589
"9.0.1.4" "9.1.0.4" "9.2.0.5"
@@ -97,6 +101,7 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_
97101
"10.8.0.43"
98102
"10.9.0.34"
99103
"10.12.0.36"
104+
"10.14.1.48"
100105
)
101106

102107
if(NOT ARG_VERSION IN_LIST downloadable_versions)
@@ -164,6 +169,8 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_
164169
elseif(ARG_VERSION VERSION_GREATER 10.10
165170
AND ARG_VERSION VERSION_LESS 10.13)
166171
set(TRT_CUDA_VERSION 12.9)
172+
elseif(ARG_VERSION VERSION_GREATER 10.13)
173+
set(TRT_CUDA_VERSION 13.0)
167174
endif()
168175

169176
# Handle TRT 8 versions.

mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/dialects/test_tensorrt.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ def test_attributes():
4949
tensorrt.TripLimitAttr.get("kWHILE"),
5050
tensorrt.FillOperationAttr.get("kRANDOM_UNIFORM"),
5151
tensorrt.ScatterModeAttr.get("kELEMENT"),
52+
tensorrt.AttentionNormalizationOpAttr.get("kSOFTMAX"),
53+
tensorrt.DataTypeAttr.get("kFLOAT"),
5254
]:
5355
print(attr)
5456

@@ -74,3 +76,5 @@ def test_attributes():
7476
# CHECK-NEXT: #tensorrt.trip_limit<kWHILE>
7577
# CHECK-NEXT: #tensorrt.fill_operation<kRANDOM_UNIFORM>
7678
# CHECK-NEXT: #tensorrt.scatter_mode<kELEMENT>
79+
# CHECK-NEXT: #tensorrt.attention_normalization_op<kSOFTMAX>
80+
# CHECK-NEXT: #tensorrt.data_type<kFLOAT>

mlir-tensorrt/compiler/tools/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,5 @@ set(LLVM_LINK_COMPONENTS
2121
add_subdirectory(mlir-tensorrt-opt)
2222
add_subdirectory(mlir-tensorrt-compiler)
2323
add_subdirectory(mlir-tensorrt-translate)
24-
add_subdirectory(mlir-tensorrt-lsp-server)
24+
# add_subdirectory(mlir-tensorrt-lsp-server)
2525
add_subdirectory(mlir-tensorrt-runner)

mlir-tensorrt/integrations/python/setup_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import subprocess
1414
import atexit
1515

16-
TENSORRT_VERSION = os.getenv("MLIR_TRT_DOWNLOAD_TENSORRT_VERSION", "10.12")
16+
TENSORRT_VERSION = os.getenv("MLIR_TRT_DOWNLOAD_TENSORRT_VERSION", "10.14")
1717

1818

1919
def log(*args):
@@ -105,8 +105,8 @@ def run_cmake_build(python_package_name: str, python_wheel_staging_dir: Path):
105105

106106
# Environment variable overrides
107107
cmake_preset = os.environ.get("MLIR_TRT_CMAKE_PRESET", "python-wheel-build")
108-
install_prefix = os.environ.get("MLIR_TRT_INSTALL_DIR", None)
109-
build_dir = os.environ.get("MLIR_TRT_BUILD_DIR", None)
108+
install_prefix = os.environ.get("MLIR_TRT_INSTALL_DIR", "./install")
109+
build_dir = os.environ.get("MLIR_TRT_BUILD_DIR", "./build")
110110
parallel_jobs = os.environ.get("MLIR_TRT_PARALLEL_JOBS", str(os.cpu_count() or 1))
111111

112112
# Additional CMake options from environment

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect-c/TensorRTAttributes.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,22 @@ DECLARE_ATTR_GETTER_FROM_STRING(ScatterMode)
188188
DECLARE_IS_ATTR(ScatterMode)
189189
DECLARE_STRING_GETTER_FROM_ATTR(ScatterMode)
190190

191+
//===----------------------------------------------------------------------===//
192+
// AttentionNormalizationOp
193+
//===----------------------------------------------------------------------===//
194+
195+
DECLARE_ATTR_GETTER_FROM_STRING(AttentionNormalizationOp)
196+
DECLARE_IS_ATTR(AttentionNormalizationOp)
197+
DECLARE_STRING_GETTER_FROM_ATTR(AttentionNormalizationOp)
198+
199+
//===----------------------------------------------------------------------===//
200+
// DataType
201+
//===----------------------------------------------------------------------===//
202+
203+
DECLARE_ATTR_GETTER_FROM_STRING(DataType)
204+
DECLARE_IS_ATTR(DataType)
205+
DECLARE_STRING_GETTER_FROM_ATTR(DataType)
206+
191207
#ifdef __cplusplus
192208
}
193209
#endif

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4485,7 +4485,8 @@ def TensorRT_AttentionOp : TensorRT_Op<"attention",
44854485
indicates the position is allowed to attend. For other types, mask values
44864486
are added to BMM1 output.
44874487
- NormalizationQuantizeScale (optional): tensor of type f32, f16, or bf16
4488-
with rank 0 or 1, used for quantizing the normalization output.
4488+
with rank 0 (scalar) or 1 (1D tensor), used for quantizing the normalization output.
4489+
Required when normalization_quantize_to_type is specified.
44894490

44904491
#### Attributes:
44914492

@@ -4502,6 +4503,10 @@ def TensorRT_AttentionOp : TensorRT_Op<"attention",
45024503
- If normalization_quantize_to_type is specified:
45034504
* It must be kFP8 or kINT8
45044505
* normalization_quantize_scale input must be provided
4506+
- If normalization_quantize_scale is provided:
4507+
* normalization_quantize_to_type must be specified
4508+
* Element type must be f32, f16, or bf16
4509+
* Rank must be 0 (scalar) or 1 (1D tensor)
45054510
- Cannot use both mask input and causal=true simultaneously
45064511

45074512
#### Examples:
@@ -4539,7 +4544,7 @@ def TensorRT_AttentionOp : TensorRT_Op<"attention",
45394544
TensorRT_RankedTensorOf<[F16, BF16, F32]>:$value,
45404545
Optional<TensorRT_Tensor>:$mask,
45414546
Optional<TensorRT_RankedTensorOf<[F16, BF16, F32]>>:$normalization_quantize_scale,
4542-
OptionalAttr<TensorRT_AttentionNormalizationOpAttr>:$normalization_operation,
4547+
DefaultValuedAttr<TensorRT_AttentionNormalizationOpAttr, "tensorrt::AttentionNormalizationOp::kSOFTMAX">:$normalization_operation,
45434548
DefaultValuedAttr<BoolAttr, "false">:$causal,
45444549
DefaultValuedAttr<BoolAttr, "false">:$decomposable,
45454550
OptionalAttr<TensorRT_DataTypeAttr>:$normalization_quantize_to_type
@@ -4565,12 +4570,7 @@ def TensorRT_AttentionOp : TensorRT_Op<"attention",
45654570
}] # baseClassDeclaration;
45664571

45674572
let trtLayerAdd = [{
4568-
// Get normalization operation, default to kSOFTMAX
4569-
nvinfer1::AttentionNormalizationOp normOp = $normalization_operation
4570-
? *$normalization_operation
4571-
: nvinfer1::AttentionNormalizationOp::kSOFTMAX;
4572-
4573-
nvinfer1::IAttention *layer = $net->addAttention(*$query, *$key, *$value, normOp, $causal);
4573+
nvinfer1::IAttention *layer = $net->addAttention(*$query, *$key, *$value, *$normalization_operation, $causal);
45744574
if (!layer)
45754575
return failure();
45764576

@@ -4584,19 +4584,22 @@ def TensorRT_AttentionOp : TensorRT_Op<"attention",
45844584
}
45854585

45864586
if ($normalization_quantize_to_type) {
4587-
layer->setNormalizationQuantizeToType(*$normalization_quantize_to_type);
4587+
auto convertedDataType = ::mlir::tensorrt::convertDataTypeToNvInferEnum(*$normalization_quantize_to_type);
4588+
if (!convertedDataType)
4589+
return emitError($op->getLoc()) << "failed to convert DataType to nvinfer enum";
4590+
layer->setNormalizationQuantizeToType(*convertedDataType);
45884591
}
45894592

45904593
if (!$e.isStronglyTyped()){
45914594
FailureOr<nvinfer1::DataType> outputTrtType = getNvInferDataType($op.getLoc(),
45924595
$op.getType().getElementType());
45934596
if (failed(outputTrtType))
45944597
return failure();
4595-
layer->setOutputType(0, *outputTrtType);
45964598
}
45974599

45984600
$results.push_back(layer->getOutput(0));
4599-
$e.setMetadata(layer, $op);
4601+
// TODO: nvinfer1::IAttention does not have setMetadata API in 10.14
4602+
// layer->setMetadata($op);
46004603
}];
46014604
}
46024605

mlir-tensorrt/tensorrt/lib/Bindings/Python/DialectTensorRT.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,6 @@ PYBIND11_MODULE(_tensorrt, m) {
7777
ADD_PYTHON_ATTRIBUTE_ADAPTOR(TripLimit)
7878
ADD_PYTHON_ATTRIBUTE_ADAPTOR(FillOperation)
7979
ADD_PYTHON_ATTRIBUTE_ADAPTOR(ScatterMode)
80+
ADD_PYTHON_ATTRIBUTE_ADAPTOR(AttentionNormalizationOp)
81+
ADD_PYTHON_ATTRIBUTE_ADAPTOR(DataType)
8082
}

mlir-tensorrt/tensorrt/lib/CAPI/TensorRTAttributes.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,11 @@ DEFINE_STRING_GETTER_FROM_ATTR(FillOperation)
121121
DEFINE_ATTR_GETTER_FROM_STRING(ScatterMode)
122122
DEFINE_IS_ATTR(ScatterMode)
123123
DEFINE_STRING_GETTER_FROM_ATTR(ScatterMode)
124+
125+
DEFINE_ATTR_GETTER_FROM_STRING(AttentionNormalizationOp)
126+
DEFINE_IS_ATTR(AttentionNormalizationOp)
127+
DEFINE_STRING_GETTER_FROM_ATTR(AttentionNormalizationOp)
128+
129+
DEFINE_ATTR_GETTER_FROM_STRING(DataType)
130+
DEFINE_IS_ATTR(DataType)
131+
DEFINE_STRING_GETTER_FROM_ATTR(DataType)

0 commit comments

Comments
 (0)