|
| 1 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +# Licensed under the MIT license. |
| 3 | + |
| 4 | +import torch |
| 5 | +import torch.nn as nn |
| 6 | +import torch.nn.functional as F |
| 7 | + |
| 8 | +import numpy as np |
| 9 | + |
| 10 | +from block_zoo.BaseLayer import BaseLayer, BaseConf |
| 11 | +from utils.DocInherit import DocInherit |
| 12 | + |
| 13 | + |
class Pooling1DConf(BaseConf):
    """Configuration for the :class:`Pooling1D` layer.

    Args:
        pool_type (str): 'max' or 'mean', default is 'max'.
        stride (int): which axis to conduct pooling, default is 1.
        padding (int): implicit zero paddings on both sides of the input. Can be a single number or a tuple (padH, padW). Default: 0
        window_size (int): the size of the pooling

    """

    def __init__(self, **kwargs):
        super(Pooling1DConf, self).__init__(**kwargs)

    @DocInherit
    def default(self):
        # Supported pool types: 'max' and 'mean'.
        self.pool_type = 'max'
        self.stride = 1
        self.padding = 0
        self.window_size = 3

    @DocInherit
    def declare(self):
        # Exactly one rank-3 input: [batch_size, seq_len, feature_dim].
        self.num_of_inputs = 1
        self.input_ranks = [3]

    @DocInherit
    def inference(self):
        batch_dim = self.input_dims[0][0]
        seq_dim = self.input_dims[0][1]
        feature_dim = self.input_dims[0][-1]

        # -1 marks an unknown/dynamic sequence length; otherwise apply the
        # standard pooled-length formula: (L + 2*pad - window) // stride + 1.
        if seq_dim == -1:
            pooled_len = -1
        else:
            pooled_len = (seq_dim + 2 * self.padding - self.window_size) // self.stride + 1

        self.output_dim = [batch_dim, pooled_len, feature_dim]
        # DON'T MODIFY THIS
        self.output_rank = len(self.output_dim)

    @DocInherit
    def verify(self):
        super(Pooling1DConf, self).verify()

        # pool_type must be supplied by the user and be one of the supported kinds.
        for required_attr in ('pool_type',):
            self.add_attr_exist_assertion_for_user(required_attr)

        self.add_attr_value_assertion('pool_type', ['max', 'mean'])

        assert self.output_dim[-1] != -1, \
            "The shape of input is %s , and the input channel number of pooling should not be -1." % (
                str(self.input_dims[0]))
| 68 | + |
| 69 | + |
class Pooling1D(BaseLayer):
    """ Pooling layer

    Applies 1D max or mean pooling along the sequence (length) dimension.

    Args:
        layer_conf (Pooling1DConf): configuration of a layer

    Raises:
        ValueError: if ``layer_conf.pool_type`` is neither 'max' nor 'mean'.
    """

    def __init__(self, layer_conf):
        super(Pooling1D, self).__init__(layer_conf)
        if layer_conf.pool_type == "max":
            self.pool = nn.MaxPool1d(kernel_size=layer_conf.window_size, stride=layer_conf.stride,
                                     padding=layer_conf.padding)
        elif layer_conf.pool_type == "mean":
            self.pool = nn.AvgPool1d(kernel_size=layer_conf.window_size, stride=layer_conf.stride,
                                     padding=layer_conf.padding)
        else:
            # Fail fast: previously self.pool was silently left as None and the
            # error only surfaced later in forward() as a confusing TypeError.
            raise ValueError(
                "Pooling1D: unsupported pool_type %s; supported: ['max', 'mean']"
                % str(layer_conf.pool_type))

    def forward(self, string, string_len=None):
        """ process inputs

        Args:
            string (Tensor): tensor with shape: [batch_size, length, feature_dim]
            string_len (Tensor): [batch_size], default is None.

        Returns:
            Tensor: Pooling result of string

        """
        # nn.MaxPool1d / nn.AvgPool1d pool over the last axis of a
        # [batch, channels, length] tensor, so move feature_dim into the channel
        # position, pool over length, then restore [batch, length', feature_dim].
        string = string.permute([0, 2, 1]).contiguous()
        string = self.pool(string)
        string = string.permute([0, 2, 1]).contiguous()
        # NOTE(review): string_len is returned unchanged — it is NOT adjusted
        # for the pooled sequence length; confirm downstream layers expect this.
        return string, string_len
| 103 | + |
| 104 | + |
0 commit comments