@@ -563,6 +563,70 @@ void CLImageConverterNBlock::ImageToNCHW(void *image,
563
563
const DDim &image_dim,
564
564
const DDim &tensor_dim) {}
565
565
566
+ DDim CLImageConverterN2Block::InitImageDimInfoWith (const DDim &tensor_dim) {
567
+ CHECK (tensor_dim.size () == 4 ) << " Tensor dim is not 4." ;
568
+ size_t N, C, H, W;
569
+ N = tensor_dim[0 ];
570
+ C = tensor_dim[1 ];
571
+ H = tensor_dim[2 ];
572
+ W = tensor_dim[3 ];
573
+ size_t width = (C + 3 ) / 4 * 2 * 4 ;
574
+ size_t height = ((N + 7 ) / 8 ) * H * W;
575
+ return DDim (
576
+ std::vector<DDim::value_type>({static_cast <DDim::value_type>(width),
577
+ static_cast <DDim::value_type>(height)}));
578
+ }
579
+
580
+ void CLImageConverterN2Block::NCHWToImage (float *nchw,
581
+ void *image,
582
+ const DDim &tensor_dim) {
583
+ CHECK (tensor_dim.size () == 4 ) << " Tensor dim is not 4." ;
584
+ size_t N, C, H, W;
585
+ N = tensor_dim[0 ];
586
+ C = tensor_dim[1 ];
587
+ H = tensor_dim[2 ];
588
+ W = tensor_dim[3 ];
589
+
590
+ DDim in_image_dim = InitImageDimInfoWith (tensor_dim);
591
+
592
+ VLOG (3 ) << " tensor dim: " << tensor_dim;
593
+ VLOG (3 ) << " image dim: " << in_image_dim;
594
+
595
+ size_t height = in_image_dim[1 ];
596
+ size_t n_block = height / (W * H);
597
+ size_t c_block = (C + 3 ) / 4 ;
598
+
599
+ float *image_fp32 = static_cast <float *>(image);
600
+ half_t *image_fp16 = static_cast <half_t *>(image);
601
+
602
+ float *p = nchw;
603
+ size_t i0 = 0 ;
604
+ for (size_t n = 0 ; n < n_block * 8 ; n++) {
605
+ for (size_t c = 0 ; c < c_block * 4 ; c++) {
606
+ for (size_t h = 0 ; h < H; h++) {
607
+ for (size_t w = 0 ; w < W; w++) {
608
+ size_t img_idx = ((n / 8 ) * W * H + h * W + w) * c_block * 4 * 8 +
609
+ (c / 4 ) * 32 + ((n % 8 ) / 4 ) * 16 + (c % 4 ) * 4 +
610
+ (n % 8 ) % 4 ;
611
+ if (n < N && c < C) {
612
+ fp16_support_ ? image_fp16[img_idx] = Float2Half (*p)
613
+ : image_fp32[img_idx] = *p;
614
+ p++;
615
+ } else {
616
+ fp16_support_ ? image_fp16[img_idx] = Float2Half (0 .f )
617
+ : image_fp32[img_idx] = 0 .f ;
618
+ }
619
+ }
620
+ }
621
+ }
622
+ }
623
+ }
624
+
625
+ void CLImageConverterN2Block::ImageToNCHW (void *image,
626
+ float *tensor,
627
+ const DDim &image_dim,
628
+ const DDim &tensor_dim) {}
629
+
566
630
DDim CLImageConverterDWFilter::InitImageDimInfoWith (const DDim &tensor_dim) {
567
631
CHECK (tensor_dim.size () == 4 ) << " Tensor dim is not 4." ;
568
632
size_t N, C, H, W;
0 commit comments