@@ -49,49 +49,16 @@ void VTanh(const T* x, T* y, int n) {
49
49
}
50
50
51
51
void Softmax (const T* x, T* y, int n, int bs) {
52
- typename XRNTuples<T>::func_type compute_hmax{nullptr };
53
- typename XRNTuples<T>::func_type compute_hsum{nullptr };
54
- typename AXYNTuples<T>::func_type compute_vscal{nullptr };
55
- typename AXYNTuples<T>::func_type compute_vaddbias{nullptr };
56
- typename XYNTuples<T>::func_type compute_vexp{nullptr };
57
-
58
- if (!KernelFuncsCache<kHMax , XRNTuples<T>>::Instance ().Has (n)) {
59
- compute_hmax = Get<kHMax , XRNTuples<T>, platform::CPUPlace>(n);
60
- KernelFuncsCache<kHMax , XRNTuples<T>>::Instance ().Insert (n, compute_hmax);
61
- } else {
62
- compute_hmax = KernelFuncsCache<kHMax , XRNTuples<T>>::Instance ().At (n);
63
- }
64
-
65
- if (!KernelFuncsCache<kHSum , XRNTuples<T>>::Instance ().Has (n)) {
66
- compute_hsum = Get<kHSum , XRNTuples<T>, platform::CPUPlace>(n);
67
- KernelFuncsCache<kHSum , XRNTuples<T>>::Instance ().Insert (n, compute_hsum);
68
- } else {
69
- compute_hsum = KernelFuncsCache<kHSum , XRNTuples<T>>::Instance ().At (n);
70
- }
71
-
72
- if (!KernelFuncsCache<kVScal , AXYNTuples<T>>::Instance ().Has (n)) {
73
- compute_vscal = Get<kVScal , AXYNTuples<T>, platform::CPUPlace>(n);
74
- KernelFuncsCache<kVScal , AXYNTuples<T>>::Instance ().Insert (n,
75
- compute_vscal);
76
- } else {
77
- compute_vscal = KernelFuncsCache<kVScal , AXYNTuples<T>>::Instance ().At (n);
78
- }
79
-
80
- if (!KernelFuncsCache<kVAddBias , AXYNTuples<T>>::Instance ().Has (n)) {
81
- compute_vaddbias = Get<kVAddBias , AXYNTuples<T>, platform::CPUPlace>(n);
82
- KernelFuncsCache<kVAddBias , AXYNTuples<T>>::Instance ().Insert (
83
- n, compute_vaddbias);
84
- } else {
85
- compute_vaddbias =
86
- KernelFuncsCache<kVAddBias , AXYNTuples<T>>::Instance ().At (n);
87
- }
88
-
89
- if (!KernelFuncsCache<kVExp , XYNTuples<T>>::Instance ().Has (n)) {
90
- compute_vexp = Get<KernelType::kVExp , XYNTuples<T>, platform::CPUPlace>(n);
91
- KernelFuncsCache<kVExp , XYNTuples<T>>::Instance ().Insert (n, compute_vexp);
92
- } else {
93
- compute_vexp = KernelFuncsCache<kVExp , XYNTuples<T>>::Instance ().At (n);
94
- }
52
+ auto compute_hmax =
53
+ KernelFuncs<kHMax , XRNTuples<T>, platform::CPUPlace>::Cache ().At (n);
54
+ auto compute_hsum =
55
+ KernelFuncs<kHSum , XRNTuples<T>, platform::CPUPlace>::Cache ().At (n);
56
+ auto compute_vscal =
57
+ KernelFuncs<kVScal , AXYNTuples<T>, platform::CPUPlace>::Cache ().At (n);
58
+ auto compute_vaddbias =
59
+ KernelFuncs<kVAddBias , AXYNTuples<T>, platform::CPUPlace>::Cache ().At (n);
60
+ auto compute_vexp =
61
+ KernelFuncs<kVExp , XYNTuples<T>, platform::CPUPlace>::Cache ().At (n);
95
62
96
63
for (int i = 0 ; i < bs; ++i) {
97
64
T scalar;
0 commit comments