Skip to content

Commit 9f736e9

Browse files
committed
test(llm): 添加 attention 前端单测
Signed-off-by: YdrMaster <[email protected]>
1 parent d7bbd3b commit 9f736e9

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#include "../src/operators/attention.hh"
2+
#include "llm/operators.h"
3+
#include <gtest/gtest.h>
4+
5+
using namespace refactor;
6+
using namespace llm;
7+
8+
/// Front-end inference test for `Attention` when only the query, key and value
/// tensors are supplied (three inputs, no cache tensors).
///
/// Two configurations are checked:
///   1. key/value carry the same head count as the query (16);
///   2. key/value carry a smaller head count (4) than the query (16).
/// In both, inference must succeed and yield exactly one FP16 output whose
/// shape equals the query's shape.
TEST(infer, AttentionNoKvCache) {
    llm::register_();

    auto batch = DimExpr("N");// symbolic batch dimension
    auto numHead = DimExpr(16);
    auto seqLen = DimExpr(31);
    auto headDim = DimExpr(64);

    // Builds {query, key, value} with `kvHead` as the head dimension of key and
    // value, runs inference, and verifies the single output matches the query.
    // NOTE(review): the meaning of the `0` passed to `Attention` and of the
    // `{true}` inference options is not visible in this file — confirm
    // against the operator's declaration.
    auto const expectQueryShapedOutput = [&](DimExpr const &kvHead) -> void {
        auto edges = Edges{
            {Tensor::share(DataType::FP16, Shape{batch, numHead, seqLen, headDim}, {}), ""},
            {Tensor::share(DataType::FP16, Shape{batch, kvHead, seqLen, headDim}, {}), ""},
            {Tensor::share(DataType::FP16, Shape{batch, kvHead, seqLen, headDim}, {}), ""},
        };
        count_t inputs[]{0, 1, 2};
        auto inferred = Attention(0).infer(TensorRefs(edges, inputs), {true});
        ASSERT_TRUE(inferred.isOk());
        auto outputs = std::move(inferred.unwrap());
        // Unsigned literal avoids a signed/unsigned comparison inside gtest.
        ASSERT_EQ(outputs.size(), 1u);
        auto y = std::move(outputs[0]);
        ASSERT_EQ(y->dataType, DataType::FP16);
        ASSERT_EQ(y->shape, edges[0].tensor->shape);
    };

    // A fatal failure inside the checker only returns from the lambda, so each
    // call is wrapped to abort the whole test, as the inline version did.
    ASSERT_NO_FATAL_FAILURE(expectQueryShapedOutput(numHead));
    ASSERT_NO_FATAL_FAILURE(expectQueryShapedOutput(DimExpr(4)));
}

0 commit comments

Comments
 (0)