|
1 | 1 | #include "linear.h" |
2 | 2 | #include "kernels/linear_kernels.h" |
3 | 3 | #include "local-execution/task_argument_accessor.h" |
4 | | -#include "op-attrs/ff_dim.h" |
| 4 | +#include "op-attrs/ff_dim_t.h" |
5 | 5 | #include "op-attrs/get_output_shapes.h" |
6 | 6 | #include "utils/exception.h" |
7 | 7 | #include "utils/hash-utils.h" |
@@ -66,8 +66,8 @@ static DeviceSpecificDeviceStates |
66 | 66 | auto input = acc.get_tensor<Permissions::RO>(INPUT); |
67 | 67 | auto weight = acc.get_tensor<Permissions::RO>(WEIGHT); |
68 | 68 | auto output = acc.get_tensor<Permissions::WO>(OUTPUT); |
69 | | - int out_dim = output.shape.at(ff_dim_t{0}); |
70 | | - int batch_size = output.shape.at(ff_dim_t{1}); |
| 69 | + int out_dim = output.shape.at(ff_dim_t{nonnegative_int{0}}); |
| 70 | + int batch_size = output.shape.at(ff_dim_t{nonnegative_int{1}}); |
71 | 71 |
|
72 | 72 | float *one_ptr; |
73 | 73 |
|
@@ -96,8 +96,8 @@ static std::optional<float> forward_task_impl(TaskArgumentAccessor const &acc) { |
96 | 96 | ProfilingSettings profiling = acc.get_argument<ProfilingSettings>(PROFILING); |
97 | 97 | auto attrs = acc.get_argument<LinearAttrs>(ATTRS); |
98 | 98 |
|
99 | | - int in_dim = input.shape.at(ff_dim_t{0}) + 1; |
100 | | - int out_dim = output.shape.at(ff_dim_t{0}) + 1; |
| 99 | + int in_dim = input.shape.at(ff_dim_t{nonnegative_int{0}}) + 1; |
| 100 | + int out_dim = output.shape.at(ff_dim_t{nonnegative_int{0}}) + 1; |
101 | 101 | int batch_size = output.shape.get_volume() / out_dim; |
102 | 102 |
|
103 | 103 | float const *bias_ptr = NULL; |
@@ -140,8 +140,8 @@ static std::optional<float> |
140 | 140 | bias_ptr = bias.get_float_ptr(); |
141 | 141 | } |
142 | 142 |
|
143 | | - int in_dim = input.shape.at(ff_dim_t{0}) + 1; |
144 | | - int out_dim = output.shape.at(ff_dim_t{0}) + 1; |
| 143 | + int in_dim = input.shape.at(ff_dim_t{nonnegative_int{0}}) + 1; |
| 144 | + int out_dim = output.shape.at(ff_dim_t{nonnegative_int{0}}) + 1; |
145 | 145 | int batch_size = output.shape.get_volume() / out_dim; |
146 | 146 |
|
147 | 147 | return profile(backward_kernel, |
|
0 commit comments