Skip to content

Commit f201439

Browse files
committed
修复多次推理结果出错问题
1 parent 2ed4d69 commit f201439

File tree

2 files changed

+39
-4
lines changed

2 files changed

+39
-4
lines changed

src/ax_translate.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,26 @@ int ax_translate_deinit(ax_translate_handle_t handle)
8686
int ax_translate(ax_translate_handle_t handle, ax_translate_io_t *io)
8787
{
8888
ax_translate_t *translate = (ax_translate_t *)handle;
89+
if (translate == nullptr)
90+
{
91+
printf("translate is null\n");
92+
return -1;
93+
}
94+
for (size_t i = 0; i < translate->m_runner->get_num_inputs(); i++)
95+
{
96+
memset(translate->m_runner->get_input(i).pVirAddr, 0, translate->m_runner->get_input(i).nSize);
97+
}
98+
8999
std::vector<int> output_ids;
90100
std::vector<int> input_ids;
91101
std::vector<int> mask;
92102
int len = translate->tokenizer.encode(io->input, MAX_LENGTH, false, input_ids, &mask);
103+
// printf("len: %d [", len);
104+
// for (int i = 0; i < len; i++)
105+
// {
106+
// printf("%d ", input_ids[i]);
107+
// }
108+
// printf("]\n");
93109

94110
std::vector<int> decoder_input_ids(MAX_LENGTH, translate->tokenizer.get_pad_id());
95111
std::vector<int> decoder_attention_mask(MAX_LENGTH, 0);

tests/test_translate.cpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "cmdline.hpp"
33
#include <cstring>
44
#include <cstdio>
5+
#include <fstream>
56

67
class Translator
78
{
@@ -95,7 +96,7 @@ int main(int argc, char *argv[])
9596
cmdline::parser parser;
9697
parser.add<std::string>("model", 'm', " model path for axmodel)", true);
9798
parser.add<std::string>("tokenizer_dir", 'k', "tokenizer dir", true);
98-
parser.add<std::string>("text", 't', "text to translate", true);
99+
parser.add<std::string>("text", 't', "text or .txt file to translate", true);
99100
parser.parse_check(argc, argv);
100101

101102
std::string model_path = parser.get<std::string>("model");
@@ -109,8 +110,26 @@ int main(int argc, char *argv[])
109110
printf("init translator failed\n");
110111
return -1;
111112
}
112-
113-
std::string output = translator.Translate(text);
114-
printf("output: %s\n", output.c_str());
113+
if (text.find(".txt") != std::string::npos)
114+
{
115+
std::ifstream ifs(text);
116+
if (!ifs.is_open())
117+
{
118+
printf("open file failed\n");
119+
return -1;
120+
}
121+
std::string line;
122+
while (std::getline(ifs, line))
123+
{
124+
std::string output = translator.Translate(line);
125+
printf("input: %s, output: %s\n", line.c_str(), output.c_str());
126+
}
127+
ifs.close();
128+
}
129+
else
130+
{
131+
std::string output = translator.Translate(text);
132+
printf("output: %s\n", output.c_str());
133+
}
115134
return 0;
116135
}

0 commit comments

Comments
 (0)