Skip to content

Commit 7286459

Browse files
committed
fix: 整理和改正 attention 构造问题
Signed-off-by: YdrMaster <[email protected]>
1 parent 20a34ac commit 7286459

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

src/05computation/include/computation/operators/attention.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
namespace refactor::computation {
77

88
struct Attention final : public Operator {
9-
dim_t maxSeqLen;
109

11-
constexpr Attention(decltype(maxSeqLen) maxSeqLen_) noexcept
12-
: Operator(), maxSeqLen(maxSeqLen_) {}
10+
constexpr Attention() noexcept = default;
1311

1412
static size_t typeId() noexcept;
1513
size_t opTypeId() const noexcept final;
1614
std::string_view name() const noexcept final;
15+
kernel::CollectorBox candidateKernels(Target) const final;
16+
std::string serialize() const noexcept final;
1717
};
1818

1919
}// namespace refactor::computation

src/05computation/src/operators/attention.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "computation/operators/attention.h"
2+
#include "kernel/collectors/attention.h"
23

34
namespace refactor::computation {
45
using Op = Attention;
@@ -9,5 +10,12 @@ namespace refactor::computation {
910
}
1011
auto Op::opTypeId() const noexcept -> size_t { return typeId(); }
1112
auto Op::name() const noexcept -> std::string_view { return "Attention"; }
13+
auto Op::candidateKernels(Target target) const -> kernel::CollectorBox {
14+
using Collector_ = kernel::AttentionCollector;
15+
return std::make_unique<Collector_>(target);
16+
}
17+
auto Op::serialize() const noexcept -> std::string {
18+
return "Attention()";
19+
}
1220

1321
}// namespace refactor::computation

src/08-01llm/src/operators/attention.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace refactor::llm {
99
: Operator(), maxSeqLen(maxSeqLen_) {}
1010

1111
auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox {
12-
auto maxSeqLen = attributes.getOrInsert("max_seq_len", {0}).float_();
12+
auto maxSeqLen = attributes.getOrInsert("max_seq_len", {0}).int_();
1313
return OpBox(std::make_unique<Op>(maxSeqLen));
1414
}
1515
auto Op::typeId() -> size_t {
@@ -129,7 +129,7 @@ namespace refactor::llm {
129129

130130
auto Op::lower(TensorRefs) const -> computation::OpBox {
131131
using Op_ = computation::Attention;
132-
return std::make_unique<Op_>(maxSeqLen);
132+
return std::make_unique<Op_>();
133133
}
134134

135135
}// namespace refactor::llm

0 commit comments

Comments
 (0)