|
3 | 3 | use co::plugin::numeric_helpers::Float; |
4 | 4 | use co::memory::MemoryType; |
5 | 5 |
|
6 | | -#[derive(Debug, Copy, Clone)] |
7 | | -#[allow(missing_docs)] |
8 | | -pub struct ConvolutionConfig; |
9 | 6 | #[derive(Debug, Copy, Clone)] |
10 | 7 | #[allow(missing_docs)] |
11 | 8 | pub struct NormalizationConfig; |
@@ -278,66 +275,12 @@ macro_rules! impl_ops_tanh_for { |
278 | 275 | ); |
279 | 276 | } |
280 | 277 |
|
281 | | -#[macro_export] |
282 | | -macro_rules! impl_ops_convolution_for { |
283 | | - ($t:ident, $b:ty) => ( |
284 | | - impl ::plugin::Convolution<$t> for $b { |
285 | | - fn new_convolution_config( |
286 | | - &self, |
287 | | - src: &::co::tensor::SharedTensor<$t>, |
288 | | - dest: &::co::tensor::SharedTensor<$t>, |
289 | | - filter: &mut ::co::tensor::SharedTensor<$t>, |
290 | | - stride: &[i32], |
291 | | - zero_padding: &[i32] |
292 | | - ) -> Result<Self::CC, ::co::error::Error> { |
293 | | - unimplemented!(); |
294 | | - Ok(helper::ConvolutionConfig) |
295 | | - } |
296 | | - fn convolution( |
297 | | - &self, |
298 | | - x: &mut ::co::tensor::SharedTensor<$t>, |
299 | | - result: &mut ::co::tensor::SharedTensor<$t>, |
300 | | - config: &Self::CC |
301 | | - ) -> Result<(), ::co::error::Error> { |
302 | | - unimplemented!(); |
303 | | - Ok(()) |
304 | | - } |
305 | | - |
306 | | - fn convolution_plain( |
307 | | - &self, |
308 | | - x: &::co::tensor::SharedTensor<$t>, |
309 | | - result: &mut ::co::tensor::SharedTensor<$t>, |
310 | | - config: &Self::CC |
311 | | - ) -> Result<(), ::co::error::Error> { |
312 | | - unimplemented!(); |
313 | | - Ok(()) |
314 | | - } |
315 | | - |
316 | | - fn convolution_grad( |
317 | | - &self, |
318 | | - x: &mut ::co::tensor::SharedTensor<$t>, |
319 | | - x_diff: &mut ::co::tensor::SharedTensor<$t>, |
320 | | - result: &mut ::co::tensor::SharedTensor<$t>, |
321 | | - result_diff: &mut ::co::tensor::SharedTensor<$t>, |
322 | | - config: &Self::CC |
323 | | - ) -> Result<(), ::co::error::Error> { |
324 | | - unimplemented!(); |
325 | | - Ok(()) |
326 | | - } |
327 | | - |
328 | | - fn convolution_grad_plain( |
329 | | - &self, |
330 | | - x: &::co::tensor::SharedTensor<$t>, |
331 | | - x_diff: &::co::tensor::SharedTensor<$t>, |
332 | | - result: &::co::tensor::SharedTensor<$t>, |
333 | | - result_diff: &mut ::co::tensor::SharedTensor<$t>, |
334 | | - config: &Self::CC |
335 | | - ) -> Result<(), ::co::error::Error> { |
336 | | - unimplemented!(); |
337 | | - Ok(()) |
338 | | - } |
339 | | - } |
340 | | - ); |
| 278 | +#[derive(Debug, Clone)] |
| 279 | +#[allow(missing_docs)] |
| 280 | +pub struct ConvolutionConfig { |
| 281 | + pub filter_shape: Vec<usize>, |
| 282 | + pub stride: Vec<i32>, |
| 283 | + pub padding: Vec<i32>, |
341 | 284 | } |
342 | 285 |
|
343 | 286 | #[macro_export] |
|
0 commit comments