Skip to content

Commit 9439059

Browse files
authored
Merge pull request #403 from guoruoqian/embedding_fix_bug
fix bugs in embedding converter
2 parents 958be30 + de269af commit 9439059

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,11 @@ auto select_registrations TRTORCH_UNUSED =
176176
{"aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> (Tensor)",
177177
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
178178
auto embeddingTensor = args[0].ITensorOrFreeze(ctx);
179-
auto indicesTensor = args[1].ITensor();
179+
auto indicesTensor = args[1].ITensorOrFreeze(ctx);
180180
// Set datatype for indices tensor to INT32
181-
indicesTensor->setType(nvinfer1::DataType::kINT32);
181+
auto identity = ctx->net->addIdentity(*indicesTensor);
182+
identity->setOutputType(0, nvinfer1::DataType::kINT32);
183+
indicesTensor = identity->getOutput(0);
182184

183185
// IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices from
184186
auto gather_layer = ctx->net->addGather(*embeddingTensor, *indicesTensor, 0);

tests/core/conversion/converters/test_select.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ TEST(Converters, ATenEmbeddingConvertsCorrectly) {
159159
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
160160

161161
// Run TensorRT
162-
auto options_trt = torch::TensorOptions().device(torch::kCUDA, 0).dtype(torch::kI32);
162+
auto options_trt = torch::TensorOptions().device(torch::kCUDA, 0).dtype(torch::kFloat);
163163
auto trt_in = at::tensor({0, 1, 2}, options_trt);
164164
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
165165
auto trt = trt_results[0].reshape(jit_results[0].sizes());

0 commit comments

Comments
 (0)