Skip to content

Commit 3336e68

Browse files
Add mbart support in triton fastertransformer (#21)
* commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit
1 parent bf3fa27 commit 3336e68

File tree

14 files changed

+422
-92
lines changed

14 files changed

+422
-92
lines changed

examples/cpp/bart/bart_triton_example.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ prepareRequest(std::string ini_name, const int node_id, const int gpu_count, std
228228
ft::FT_CHECK(false);
229229
}
230230

231-
const size_t request_batch_size = reader.GetInteger("request", "request_batch_size");
231+
const size_t request_batch_size = 1; //reader.GetInteger("request", "request_batch_size");
232232

233233
const int start_id = reader.GetInteger("decoder", "start_id");
234234
const int end_id = reader.GetInteger("decoder", "end_id");
@@ -251,6 +251,7 @@ prepareRequest(std::string ini_name, const int node_id, const int gpu_count, std
251251

252252
RequestParam param;
253253
param.beam_width = reader.GetInteger("request", "beam_width");
254+
// param.beam_width = 5;
254255
param.request_output_len = reader.GetInteger("request", "request_output_len");
255256
param.beam_search_diversity_rate = reader.GetFloat("request", "beam_search_diversity_rate");
256257
param.runtime_top_k = reader.GetInteger("request", "top_k");
@@ -261,7 +262,7 @@ prepareRequest(std::string ini_name, const int node_id, const int gpu_count, std
261262
param.presence_penalty = reader.GetFloat("request", "presence_penalty", 0.0f);
262263
param.min_length = reader.GetInteger("request", "min_length", 0);
263264
param.random_seed = (unsigned long long int)0;
264-
param.start_id = start_id;
265+
param.start_id = 250025;
265266
param.end_id = end_id;
266267

267268
auto request_list =
@@ -381,10 +382,11 @@ int main(int argc, char* argv[])
381382
}
382383

383384
const int* d_output_ids = (const int*)output_tensors_lists[0].get()->at("output_ids").data;
384-
const int batch_size = output_tensors_lists[0].get()->at("output_ids").shape[0];
385+
const int batch_size = 1; // output_tensors_lists[0].get()->at("output_ids").shape[0];
385386
const int beam_width = output_tensors_lists[0].get()->at("output_ids").shape[1];
386387
const int seq_len = output_tensors_lists[0].get()->at("output_ids").shape[2];
387388
const int* d_input_lengths = (const int*)output_tensors_lists[0].get()->at("input_sequence_lengths").data;
389+
printf("batch_size: %d beam_width: %d seq_len: %d\n", batch_size, beam_width, seq_len);
388390
// step 6: check results
389391
if (node_id == 0) {
390392

examples/cpp/bart/config.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ repetition_penalty=1.0 ; Use for sampling
1717
presence_penalty=0.0 ; Only one of repetition_penalty and presence_penalty are allowed.
1818
len_penalty=0.0
1919
beam_search_diversity_rate=0.0
20-
request_batch_size=8 # determine by the request
20+
request_batch_size=1 # determine by the request
2121
request_output_len=32 # determine by the request
2222

2323
[encoder]

examples/cpp/bart/start_ids.csv

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
1-
0, 4154, 1231, 15674, 345, 1534, 440, 50264, 11, 1854, 2
2-
0, 4154, 1231, 15674, 345, 1534, 440, 50264, 11, 1854, 2
1+
250004, 35378, 4, 765, 398, 49782, 111, 76935, 13034, 350, 32, 2

examples/cpp/multi_gpu_gpt/gpt_example_utils.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ int read_start_ids(size_t batch_size,
4545
int i1 = 0;
4646
std::vector<int> tmp_vec;
4747
while (std::getline(lineStream, vals, ',')) {
48+
printf("vals: %s\n", vals.c_str());
4849
tmp_vec.push_back(std::stoi(vals));
4950
i1++;
5051
}
@@ -88,6 +89,7 @@ int read_start_ids(size_t batch_size,
8889
for (int j = 0; j < (int)tmp_start_ids[i].size(); j++) {
8990
v_start_ids->push_back(tmp_start_ids[i][j]);
9091
}
92+
printf("tmp_start_lengths[i]: %d\n", tmp_start_lengths[i]);
9193
v_start_lengths->push_back(tmp_start_lengths[i]);
9294
}
9395
}

examples/pytorch/bart/utils/ft_encoder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import torch.nn as nn
1717
import torch.distributed as dist
1818
import numpy as np
19+
from transformers import MBartForConditionalGeneration, BartModel
1920

2021
class FTBartEncoderWeight(object):
2122
def __init__(
@@ -246,4 +247,4 @@ def __init__(self, encoder_weight_list, lib_path, head_num, head_size, inter_siz
246247

247248
def forward(self, input, seq_len, inputs_embeds=None):
248249
output = self.encoder.forward(input, seq_len, inputs_embeds)
249-
return output
250+
return output

0 commit comments

Comments
 (0)