Skip to content

Commit 4620a8e

Browse files
authored
[metal] fix pooling (#7404) (#7645)
* [metal] fix pooling * Update pool_image_compute.mm
1 parent 8584386 commit 4620a8e

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

lite/kernels/metal/image_op/pool_image_compute.mm

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@
154154
auto cmdbuf = [backend commandBuffer];
155155
if (mps_pool_op_) {
156156
if (@available(iOS 10.0, *)) {
157-
[((__bridge MPSCNNPoolingMax*)mps_pool_op_)
157+
[((__bridge MPSCNNPooling*)mps_pool_op_)
158158
encodeToCommandBuffer:cmdbuf
159159
sourceImage:(__bridge MPSImage*)mps_input_image_
160160
destinationImage:(__bridge MPSImage*)mps_output_image_];
@@ -192,16 +192,21 @@
192192
kernelHeight:kh
193193
strideInPixelsX:sw
194194
strideInPixelsY:sh];
195+
((__bridge MPSCNNPoolingMax*)mps_pool_op_).offset =
196+
MPSOffset{.x = offsetX, .y = offsetY};
197+
((__bridge MPSCNNPoolingMax*)mps_pool_op_).edgeMode = MPSImageEdgeModeZero;
195198
} else if (param.pooling_type == "avg") {
196-
mps_pool_op_ =
197-
(__bridge_retained void*)[[MPSCNNPoolingAverage alloc] initWithDevice:backend.device
198-
kernelWidth:kw
199-
kernelHeight:kh
200-
strideInPixelsX:sw
201-
strideInPixelsY:sh];
199+
mps_pool_op_ = (__bridge_retained void*)[[MPSCNNPoolingAverage alloc]
200+
initWithDevice:backend.device
201+
kernelWidth:input_buffer_->image().width
202+
kernelHeight:input_buffer_->image().height
203+
strideInPixelsX:input_buffer_->image().width
204+
strideInPixelsY:input_buffer_->image().height];
205+
((__bridge MPSCNNPoolingAverage*)mps_pool_op_).offset =
206+
MPSOffset{.x = static_cast<NSInteger>(input_buffer_->image().width / 2),
207+
.y = static_cast<NSInteger>(input_buffer_->image().height / 2)};
208+
((__bridge MPSCNNPoolingAverage*)mps_pool_op_).edgeMode = MPSImageEdgeModeZero;
202209
}
203-
((__bridge MPSCNNPoolingMax*)mps_pool_op_).offset = MPSOffset{.x = offsetX, .y = offsetY};
204-
((__bridge MPSCNNPoolingMax*)mps_pool_op_).edgeMode = MPSImageEdgeModeZero;
205210
// MPS input and output
206211
auto input_c = static_cast<int>(input_buffer_->tensor_dim_[1]);
207212
auto output_c = static_cast<int>(output_buffer_->tensor_dim_[1]);

0 commit comments

Comments
 (0)