Skip to content

Commit 8f225ad

Browse files
author
hongwen
committed
Feat:[riscv] Add rvv support to cpu/kernels.cc
Signed-off-by: lyd1992 <[email protected]>
1 parent 617405f commit 8f225ad

File tree

7 files changed

+367
-22
lines changed

7 files changed

+367
-22
lines changed

CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,14 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "(x86_64)|(amd64)|(AMD64)")
252252
add_subdirectory(third_party/cpu_features EXCLUDE_FROM_ALL)
253253
set(BUILD_SHARED_LIBS "${BUILD_SHARED_LIBS_SAVED}")
254254
list(APPEND LIBRARIES cpu_features)
255+
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "riscv64")
256+
add_definitions(-DCT2_WITH_RVV)
257+
set(CT2_BUILD_ARCH "riscv64")
258+
message(STATUS "Target architecture is RISC-V with Vector extension")
255259
endif()
256260

261+
message(STATUS "Current CT2_BUILD_ARCH is: ${CT2_BUILD_ARCH}")
262+
257263
if(ENABLE_CPU_DISPATCH)
258264
message(STATUS "Compiling for multiple CPU ISA and enabling runtime dispatch")
259265
add_definitions(-DCT2_WITH_CPU_DISPATCH)
@@ -269,6 +275,9 @@ if(ENABLE_CPU_DISPATCH)
269275
endif()
270276
elseif(CT2_BUILD_ARCH STREQUAL "arm64")
271277
ct2_compile_kernels_for_isa(neon "-DUSE_NEON")
278+
elseif(CT2_BUILD_ARCH STREQUAL "riscv64")
279+
ct2_compile_kernels_for_isa(rvv "-march=rv64gcv")
280+
message(STATUS "Current CT2_BUILD_ARCH is: ${CT2_BUILD_ARCH}")
272281
endif()
273282
endif()
274283

src/cpu/cpu_info.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,22 @@ namespace ctranslate2 {
5858
}
5959
}
6060

61+
#elif defined(CT2_WITH_RVV)
62+
63+
namespace ctranslate2 {
64+
namespace cpu {
65+
66+
const char* cpu_vendor() {
67+
return "RVV";
68+
}
69+
70+
bool cpu_supports_rvv() {
71+
return true;
72+
}
73+
74+
}
75+
}
76+
77+
78+
6179
#endif

src/cpu/cpu_info.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ namespace ctranslate2 {
1414
bool cpu_supports_avx512();
1515
#elif defined(CT2_ARM64_BUILD)
1616
bool cpu_supports_neon();
17+
#elif defined(CT2_WITH_RVV)
18+
bool cpu_supports_rvv();
1719
#endif
1820

1921
}

src/cpu/cpu_isa.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ namespace ctranslate2 {
3535
#elif defined(CT2_ARM64_BUILD)
3636
case CpuIsa::NEON:
3737
return "NEON";
38+
#elif defined(CT2_WITH_RVV)
39+
case CpuIsa::RVV:
40+
return "RVV";
3841
#endif
3942
default:
4043
return "GENERIC";
@@ -54,6 +57,9 @@ namespace ctranslate2 {
5457
#elif defined(CT2_ARM64_BUILD)
5558
if (env_isa == "NEON")
5659
return try_isa(env_isa, CpuIsa::NEON, cpu_supports_neon());
60+
#elif defined(CT2_WITH_RVV)
61+
if (env_isa == "RVV")
62+
return try_isa(env_isa, CpuIsa::RVV, cpu_supports_rvv());
5763
#endif
5864
if (env_isa == "GENERIC")
5965
return CpuIsa::GENERIC;
@@ -71,6 +77,9 @@ namespace ctranslate2 {
7177
# elif defined(CT2_ARM64_BUILD)
7278
if (cpu_supports_neon())
7379
return CpuIsa::NEON;
80+
# elif defined(CT2_WITH_RVV)
81+
if (cpu_supports_rvv())
82+
return CpuIsa::RVV;
7483
# endif
7584
#endif
7685

src/cpu/cpu_isa.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ namespace ctranslate2 {
1313
AVX512,
1414
#elif defined(CT2_ARM64_BUILD)
1515
NEON,
16+
#elif defined(CT2_WITH_RVV)
17+
RVV,
1618
#endif
1719
};
1820

@@ -46,6 +48,7 @@ namespace ctranslate2 {
4648
CPU_ISA_CASE(cpu::CpuIsa::AVX512, SINGLE_ARG(STMTS)) \
4749
CPU_ISA_CASE(cpu::CpuIsa::AVX2, SINGLE_ARG(STMTS)) \
4850
CPU_ISA_CASE(cpu::CpuIsa::AVX, SINGLE_ARG(STMTS)) \
51+
CPU_ISA_CASE(cpu::CpuIsa::RVV, SINGLE_ARG(STMTS)) \
4952
CPU_ISA_DEFAULT(cpu::CpuIsa::GENERIC, SINGLE_ARG(STMTS)) \
5053
}
5154
#elif defined(CT2_ARM64_BUILD)
@@ -54,6 +57,12 @@ namespace ctranslate2 {
5457
CPU_ISA_CASE(cpu::CpuIsa::NEON, SINGLE_ARG(STMTS)) \
5558
CPU_ISA_DEFAULT(cpu::CpuIsa::GENERIC, SINGLE_ARG(STMTS)) \
5659
}
60+
#elif defined(CT2_WITH_RVV)
61+
# define CPU_ISA_DISPATCH(STMTS) \
62+
switch (cpu::get_cpu_isa()) { \
63+
CPU_ISA_CASE(cpu::CpuIsa::RVV, SINGLE_ARG(STMTS)) \
64+
CPU_ISA_DEFAULT(cpu::CpuIsa::GENERIC, SINGLE_ARG(STMTS)) \
65+
}
5766
#endif
5867
#elif defined(__AVX512F__)
5968
# define CPU_ISA_DISPATCH(STMTS) \
@@ -75,6 +84,11 @@ namespace ctranslate2 {
7584
switch (cpu::get_cpu_isa()) { \
7685
CPU_ISA_DEFAULT(cpu::CpuIsa::NEON, SINGLE_ARG(STMTS)) \
7786
}
87+
#elif defined(__riscv_vector)
88+
# define CPU_ISA_DISPATCH(STMTS) \
89+
switch (cpu::get_cpu_isa()) { \
90+
CPU_ISA_DEFAULT(cpu::CpuIsa::RVV, SINGLE_ARG(STMTS)) \
91+
}
7892
#else
7993
# define CPU_ISA_DISPATCH(STMTS) \
8094
switch (cpu::get_cpu_isa()) { \

src/cpu/kernels.cc

Lines changed: 78 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
#elif (defined(__ARM_NEON) && !defined(CT2_WITH_CPU_DISPATCH)) || defined(USE_NEON)
1515
# define TARGET_ISA CpuIsa::NEON
1616
# include "cpu/vec_neon.h"
17+
#elif (defined(CT2_WITH_RVV) && defined(__riscv_vector))
18+
# define USE_RVV
19+
# define TARGET_ISA CpuIsa::RVV
20+
# include "cpu/vec_rvv.h"
1721
#else
1822
# define TARGET_ISA CpuIsa::GENERIC
1923
# include "cpu/vec.h"
@@ -213,7 +217,7 @@ namespace ctranslate2 {
213217

214218
template<>
215219
void exp<TARGET_ISA>(const float* x, float* y, dim_t size) {
216-
vectorized_unary_transform<TARGET_ISA>(x, y, size, Vec<float, TARGET_ISA>::exp);
220+
vectorized_unary_transform<TARGET_ISA>(x, y, size, Vec<float, TARGET_ISA>::exp);
217221
}
218222

219223
template<>
@@ -263,11 +267,20 @@ namespace ctranslate2 {
263267

264268
template <CpuIsa ISA, typename T>
265269
void add(T a, const T* x, T* y, dim_t size) {
270+
#ifdef USE_RVV
271+
T a_copy = a;
272+
vectorized_unary_transform<ISA>(x, y, size,
273+
[a_copy](vec_type<T, ISA> v) {
274+
auto vec_a = Vec<T, ISA>::load(a_copy);
275+
return Vec<T, ISA>::add(v, vec_a);
276+
});
277+
#else
266278
auto vec_a = Vec<T, ISA>::load(a);
267279
vectorized_unary_transform<ISA>(x, y, size,
268-
[vec_a](vec_type<T, ISA> v) {
269-
return Vec<T, ISA>::add(v, vec_a);
270-
});
280+
[vec_a](vec_type<T, ISA> v) {
281+
return Vec<T, ISA>::add(v, vec_a);
282+
});
283+
#endif
271284
}
272285

273286
template <CpuIsa ISA, typename T>
@@ -282,11 +295,20 @@ namespace ctranslate2 {
282295

283296
template <CpuIsa ISA, typename T>
284297
void mul(T a, const T* x, T* y, dim_t size) {
298+
#ifdef USE_RVV
299+
T a_copy = a;
300+
vectorized_unary_transform<ISA>(x, y, size,
301+
[a_copy](vec_type<T, ISA> v) {
302+
auto vec_a = Vec<T, ISA>::load(a_copy);
303+
return Vec<T, ISA>::mul(v, vec_a);
304+
});
305+
#else
285306
auto vec_a = Vec<T, ISA>::load(a);
286307
vectorized_unary_transform<ISA>(x, y, size,
287-
[vec_a](vec_type<T, ISA> v) {
288-
return Vec<T, ISA>::mul(v, vec_a);
289-
});
308+
[vec_a](vec_type<T, ISA> v) {
309+
return Vec<T, ISA>::mul(v, vec_a);
310+
});
311+
#endif
290312
}
291313

292314
template <CpuIsa ISA, typename T>
@@ -296,11 +318,20 @@ namespace ctranslate2 {
296318

297319
template <CpuIsa ISA, typename T>
298320
void max(T a, const T* x, T* y, dim_t size) {
321+
#ifdef USE_RVV
322+
T a_copy = a;
323+
vectorized_unary_transform<ISA>(x, y, size,
324+
[a_copy](vec_type<T, ISA> v) {
325+
auto vec_a = Vec<T, ISA>::load(a_copy);
326+
return Vec<T, ISA>::max(v, vec_a);
327+
});
328+
#else
299329
auto vec_a = Vec<T, ISA>::load(a);
300330
vectorized_unary_transform<ISA>(x, y, size,
301-
[vec_a](vec_type<T, ISA> v) {
302-
return Vec<T, ISA>::max(v, vec_a);
303-
});
331+
[vec_a](vec_type<T, ISA> v) {
332+
return Vec<T, ISA>::max(v, vec_a);
333+
});
334+
#endif
304335
}
305336

306337
template <CpuIsa ISA, typename T>
@@ -310,11 +341,20 @@ namespace ctranslate2 {
310341

311342
template <CpuIsa ISA, typename T>
312343
void min(T a, const T* x, T* y, dim_t size) {
344+
#ifdef USE_RVV
345+
T a_copy = a;
346+
vectorized_unary_transform<ISA>(x, y, size,
347+
[a_copy](vec_type<T, ISA> v) {
348+
auto vec_a = Vec<T, ISA>::load(a_copy);
349+
return Vec<T, ISA>::min(v, vec_a);
350+
});
351+
#else
313352
auto vec_a = Vec<T, ISA>::load(a);
314353
vectorized_unary_transform<ISA>(x, y, size,
315-
[vec_a](vec_type<T, ISA> v) {
316-
return Vec<T, ISA>::min(v, vec_a);
317-
});
354+
[vec_a](vec_type<T, ISA> v) {
355+
return Vec<T, ISA>::min(v, vec_a);
356+
});
357+
#endif
318358
}
319359

320360
template <CpuIsa ISA, typename T>
@@ -349,6 +389,7 @@ namespace ctranslate2 {
349389
static_cast<T>(0),
350390
Vec<T, ISA>::abs,
351391
Vec<T, ISA>::max,
392+
352393
Vec<T, ISA>::reduce_max,
353394
Vec<T>::abs,
354395
Vec<T>::max);
@@ -377,14 +418,22 @@ namespace ctranslate2 {
377418
using VecType = Vec<float, TARGET_ISA>;
378419

379420
const auto x_max = reduce_max<TARGET_ISA>(x, size);
380-
const auto vec_x_max = VecType::load(x_max);
381421

382-
const auto scalar_exp_func = [x_max](vec_type<float> v) {
383-
return Vec<float>::exp(Vec<float>::sub(v, x_max));
422+
const auto scalar_exp_func = [x_max](float v) {
423+
return std::exp(v - x_max);
384424
};
385-
const auto vec_exp_func = [vec_x_max](vec_type<float, TARGET_ISA> v) {
425+
#ifdef USE_RVV
426+
float x_max_copy = x_max;
427+
auto vec_exp_func = [x_max_copy](vec_type<float, TARGET_ISA> v) {
428+
auto vec_x_max = VecType::load(x_max_copy);
386429
return VecType::exp(VecType::sub(v, vec_x_max));
387430
};
431+
#else
432+
const auto vec_x_max = VecType::load(x_max);
433+
auto vec_exp_func = [vec_x_max](vec_type<float, TARGET_ISA> v) {
434+
return VecType::exp(VecType::sub(v, vec_x_max));
435+
};
436+
#endif
388437

389438
const auto exp_sum = vectorized_map_reduce_all<TARGET_ISA>(
390439
x,
@@ -429,14 +478,21 @@ namespace ctranslate2 {
429478
}
430479

431480
const auto x_max = reduce_max<TARGET_ISA>(x, size);
432-
const auto vec_x_max = VecType::load(x_max);
433-
434-
const auto scalar_exp_func = [x_max](vec_type<float> v) {
435-
return Vec<float>::exp(Vec<float>::sub(v, x_max));
481+
const auto scalar_exp_func = [x_max](float v) {
482+
return std::exp(v - x_max);
483+
};
484+
#ifdef USE_RVV
485+
float x_max_copy = x_max;
486+
auto vec_exp_func = [x_max_copy](vec_type<float, TARGET_ISA> v) {
487+
auto vec_x_max = VecType::load(x_max_copy);
488+
return VecType::exp(VecType::sub(v, vec_x_max));
436489
};
437-
const auto vec_exp_func = [vec_x_max](vec_type<float, TARGET_ISA> v) {
490+
#else
491+
const auto vec_x_max = VecType::load(x_max);
492+
auto vec_exp_func = [vec_x_max](vec_type<float, TARGET_ISA> v) {
438493
return VecType::exp(VecType::sub(v, vec_x_max));
439494
};
495+
#endif
440496

441497
if (log) {
442498
const auto exp_sum = vectorized_map_reduce_all<TARGET_ISA>(

0 commit comments

Comments
 (0)