3 files changed: +13 -5 lines changed
include/computation/operators
Original file line number | Diff line number | Diff line change
66namespace refactor ::computation {
77
88 struct Attention final : public Operator {
9- dim_t maxSeqLen;
109
11- constexpr Attention (decltype (maxSeqLen) maxSeqLen_) noexcept
12- : Operator(), maxSeqLen(maxSeqLen_) {}
10+ constexpr Attention () noexcept = default;
1311
1412 static size_t typeId () noexcept ;
1513 size_t opTypeId () const noexcept final ;
1614 std::string_view name () const noexcept final ;
15+ kernel::CollectorBox candidateKernels (Target) const final ;
16+ std::string serialize () const noexcept final ;
1717 };
1818
1919}// namespace refactor::computation
Original file line number | Diff line number | Diff line change
11#include " computation/operators/attention.h"
2+ #include " kernel/collectors/attention.h"
23
34namespace refactor ::computation {
45 using Op = Attention;
@@ -9,5 +10,12 @@ namespace refactor::computation {
910 }
1011 auto Op::opTypeId () const noexcept -> size_t { return typeId (); }
1112 auto Op::name () const noexcept -> std::string_view { return " Attention" ; }
13+ auto Op::candidateKernels (Target target) const -> kernel::CollectorBox {
14+ using Collector_ = kernel::AttentionCollector;
15+ return std::make_unique<Collector_>(target);
16+ }
17+ auto Op::serialize () const noexcept -> std::string {
18+ return " Attention()" ;
19+ }
1220
1321}// namespace refactor::computation
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@ namespace refactor::llm {
99 : Operator(), maxSeqLen(maxSeqLen_) {}
1010
1111 auto Op::build (ModelContext const &, std::string_view, Attributes attributes) -> OpBox {
12- auto maxSeqLen = attributes.getOrInsert (" max_seq_len" , {0 }).float_ ();
12+ auto maxSeqLen = attributes.getOrInsert (" max_seq_len" , {0 }).int_ ();
1313 return OpBox (std::make_unique<Op>(maxSeqLen));
1414 }
1515 auto Op::typeId () -> size_t {
@@ -129,7 +129,7 @@ namespace refactor::llm {
129129
130130 auto Op::lower (TensorRefs) const -> computation::OpBox {
131131 using Op_ = computation::Attention;
132- return std::make_unique<Op_>(maxSeqLen );
132+ return std::make_unique<Op_>();
133133 }
134134
135135}// namespace refactor::llm
You can’t perform that action at this time.
0 commit comments