|
| 1 | +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
| 15 | +#include <limits> |
| 16 | +#include <random> |
| 17 | +#include "paddle/fluid/framework/op_registry.h" |
| 18 | + |
| 19 | +namespace paddle { |
| 20 | +namespace operators { |
| 21 | + |
| 22 | +// reference: https://gist.github.com/lakshayg/d80172fe5ae3c5d2c2aedb53c250320e |
| 23 | +template <typename T> |
| 24 | +T Erfinv(T x) { |
| 25 | + if (x < -1 || x > 1) { |
| 26 | + return std::numeric_limits<T>::quiet_NaN(); |
| 27 | + } else if (x == 1.0) { |
| 28 | + return std::numeric_limits<T>::infinity(); |
| 29 | + } else if (x == -1.0) { |
| 30 | + return -std::numeric_limits<T>::infinity(); |
| 31 | + } |
| 32 | + |
| 33 | + const T LN2 = 6.931471805599453094172321214581e-1; |
| 34 | + |
| 35 | + const T A0 = 1.1975323115670912564578e0; |
| 36 | + const T A1 = 4.7072688112383978012285e1; |
| 37 | + const T A2 = 6.9706266534389598238465e2; |
| 38 | + const T A3 = 4.8548868893843886794648e3; |
| 39 | + const T A4 = 1.6235862515167575384252e4; |
| 40 | + const T A5 = 2.3782041382114385731252e4; |
| 41 | + const T A6 = 1.1819493347062294404278e4; |
| 42 | + const T A7 = 8.8709406962545514830200e2; |
| 43 | + |
| 44 | + const T B0 = 1.0000000000000000000e0; |
| 45 | + const T B1 = 4.2313330701600911252e1; |
| 46 | + const T B2 = 6.8718700749205790830e2; |
| 47 | + const T B3 = 5.3941960214247511077e3; |
| 48 | + const T B4 = 2.1213794301586595867e4; |
| 49 | + const T B5 = 3.9307895800092710610e4; |
| 50 | + const T B6 = 2.8729085735721942674e4; |
| 51 | + const T B7 = 5.2264952788528545610e3; |
| 52 | + |
| 53 | + const T C0 = 1.42343711074968357734e0; |
| 54 | + const T C1 = 4.63033784615654529590e0; |
| 55 | + const T C2 = 5.76949722146069140550e0; |
| 56 | + const T C3 = 3.64784832476320460504e0; |
| 57 | + const T C4 = 1.27045825245236838258e0; |
| 58 | + const T C5 = 2.41780725177450611770e-1; |
| 59 | + const T C6 = 2.27238449892691845833e-2; |
| 60 | + const T C7 = 7.74545014278341407640e-4; |
| 61 | + |
| 62 | + const T D0 = 1.4142135623730950488016887e0; |
| 63 | + const T D1 = 2.9036514445419946173133295e0; |
| 64 | + const T D2 = 2.3707661626024532365971225e0; |
| 65 | + const T D3 = 9.7547832001787427186894837e-1; |
| 66 | + const T D4 = 2.0945065210512749128288442e-1; |
| 67 | + const T D5 = 2.1494160384252876777097297e-2; |
| 68 | + const T D6 = 7.7441459065157709165577218e-4; |
| 69 | + const T D7 = 1.4859850019840355905497876e-9; |
| 70 | + |
| 71 | + const T E0 = 6.65790464350110377720e0; |
| 72 | + const T E1 = 5.46378491116411436990e0; |
| 73 | + const T E2 = 1.78482653991729133580e0; |
| 74 | + const T E3 = 2.96560571828504891230e-1; |
| 75 | + const T E4 = 2.65321895265761230930e-2; |
| 76 | + const T E5 = 1.24266094738807843860e-3; |
| 77 | + const T E6 = 2.71155556874348757815e-5; |
| 78 | + const T E7 = 2.01033439929228813265e-7; |
| 79 | + |
| 80 | + const T F0 = 1.414213562373095048801689e0; |
| 81 | + const T F1 = 8.482908416595164588112026e-1; |
| 82 | + const T F2 = 1.936480946950659106176712e-1; |
| 83 | + const T F3 = 2.103693768272068968719679e-2; |
| 84 | + const T F4 = 1.112800997078859844711555e-3; |
| 85 | + const T F5 = 2.611088405080593625138020e-5; |
| 86 | + const T F6 = 2.010321207683943062279931e-7; |
| 87 | + const T F7 = 2.891024605872965461538222e-15; |
| 88 | + |
| 89 | + T abs_x = abs(x); |
| 90 | + |
| 91 | + if (abs_x <= 0.85) { |
| 92 | + T r = 0.180625 - 0.25 * x * x; |
| 93 | + T num = |
| 94 | + (((((((A7 * r + A6) * r + A5) * r + A4) * r + A3) * r + A2) * r + A1) * |
| 95 | + r + |
| 96 | + A0); |
| 97 | + T den = |
| 98 | + (((((((B7 * r + B6) * r + B5) * r + B4) * r + B3) * r + B2) * r + B1) * |
| 99 | + r + |
| 100 | + B0); |
| 101 | + return x * num / den; |
| 102 | + } |
| 103 | + |
| 104 | + T r = sqrt(LN2 - log(1.0 - abs_x)); |
| 105 | + |
| 106 | + T num, den; |
| 107 | + if (r <= 5.0) { |
| 108 | + r = r - 1.6; |
| 109 | + num = |
| 110 | + (((((((C7 * r + C6) * r + C5) * r + C4) * r + C3) * r + C2) * r + C1) * |
| 111 | + r + |
| 112 | + C0); |
| 113 | + den = |
| 114 | + (((((((D7 * r + D6) * r + D5) * r + D4) * r + D3) * r + D2) * r + D1) * |
| 115 | + r + |
| 116 | + D0); |
| 117 | + } else { |
| 118 | + r = r - 5.0; |
| 119 | + num = |
| 120 | + (((((((E7 * r + E6) * r + E5) * r + E4) * r + E3) * r + E2) * r + E1) * |
| 121 | + r + |
| 122 | + E0); |
| 123 | + den = |
| 124 | + (((((((F7 * r + F6) * r + F5) * r + F4) * r + F3) * r + F2) * r + F1) * |
| 125 | + r + |
| 126 | + F0); |
| 127 | + } |
| 128 | + |
| 129 | + if (x < 0) { |
| 130 | + return -num / den; |
| 131 | + } else { |
| 132 | + return num / den; |
| 133 | + } |
| 134 | +} |
| 135 | + |
| 136 | +template <typename T> |
| 137 | +struct TruncatedNormal { |
| 138 | + T mean, std; |
| 139 | + T a_normal_cdf; |
| 140 | + T b_normal_cdf; |
| 141 | + TruncatedNormal(T mean, T std) : mean(mean), std(std) { |
| 142 | + auto normal_cdf = [](T x) { |
| 143 | + return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0; |
| 144 | + }; |
| 145 | + a_normal_cdf = normal_cdf(-2.0); |
| 146 | + b_normal_cdf = normal_cdf(2.0); |
| 147 | + } |
| 148 | + |
| 149 | + T operator()(T value) const { |
| 150 | + auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value; |
| 151 | + return (std::sqrt(2.0) * Erfinv(2 * p - 1) + mean) * std; |
| 152 | + } |
| 153 | +}; |
| 154 | + |
| 155 | +template <typename T> |
| 156 | +class CPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> { |
| 157 | + public: |
| 158 | + void Compute(const framework::ExecutionContext& context) const override { |
| 159 | + float mean = context.Attr<float>("mean"); |
| 160 | + float std = context.Attr<float>("std"); |
| 161 | + auto* tensor = context.Output<framework::Tensor>("Out"); |
| 162 | + T* data = tensor->mutable_data<T>(context.GetPlace()); |
| 163 | + |
| 164 | + unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed")); |
| 165 | + std::minstd_rand engine; |
| 166 | + if (seed == 0) { |
| 167 | + seed = std::random_device()(); |
| 168 | + } |
| 169 | + engine.seed(seed); |
| 170 | + std::uniform_real_distribution<T> dist(std::numeric_limits<float>::min(), |
| 171 | + 1.0); |
| 172 | + TruncatedNormal<T> truncated_normal(mean, std); |
| 173 | + int64_t size = tensor->numel(); |
| 174 | + for (int64_t i = 0; i < size; ++i) { |
| 175 | + data[i] = truncated_normal(dist(engine)); |
| 176 | + } |
| 177 | + } |
| 178 | +}; |
| 179 | + |
| 180 | +class TruncatedGaussianRandomOp : public framework::OperatorWithKernel { |
| 181 | + public: |
| 182 | + using framework::OperatorWithKernel::OperatorWithKernel; |
| 183 | + |
| 184 | + void InferShape(framework::InferShapeContext* ctx) const override { |
| 185 | + PADDLE_ENFORCE( |
| 186 | + ctx->HasOutput("Out"), |
| 187 | + "Output(Out) of TruncatedGaussianRandomOp should not be null."); |
| 188 | + auto shape = ctx->Attrs().Get<std::vector<int>>("shape"); |
| 189 | + std::vector<int64_t> out_dim; |
| 190 | + out_dim.reserve(shape.size()); |
| 191 | + for (auto dim : shape) { |
| 192 | + out_dim.push_back(static_cast<int64_t>(dim)); |
| 193 | + } |
| 194 | + PADDLE_ENFORCE(shape.size() > 0UL, |
| 195 | + "shape can be one int or array. shape must be set."); |
| 196 | + ctx->SetOutputDim("Out", framework::make_ddim(out_dim)); |
| 197 | + } |
| 198 | + |
| 199 | + protected: |
| 200 | + framework::OpKernelType GetExpectedKernelType( |
| 201 | + const framework::ExecutionContext& ctx) const override { |
| 202 | + framework::LibraryType library{framework::LibraryType::kPlain}; |
| 203 | + framework::DataLayout layout{framework::DataLayout::kAnyLayout}; |
| 204 | + return framework::OpKernelType( |
| 205 | + static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")), |
| 206 | + ctx.device_context(), layout, library); |
| 207 | + } |
| 208 | +}; |
| 209 | + |
| 210 | +class TruncatedGaussianRandomOpMaker |
| 211 | + : public framework::OpProtoAndCheckerMaker { |
| 212 | + public: |
| 213 | + void Make() override { |
| 214 | + AddOutput("Out", "Output tensor of truncated gaussian random op."); |
| 215 | + |
| 216 | + AddAttr<std::vector<int>>("shape", |
| 217 | + "(vector<int>) " |
| 218 | + "The dimension of random tensor."); |
| 219 | + AddAttr<float>("mean", |
| 220 | + "(float, default 0.0) " |
| 221 | + "mean of random tensor.") |
| 222 | + .SetDefault(.0f); |
| 223 | + AddAttr<float>("std", |
| 224 | + "(float, default 1.0) " |
| 225 | + "std of random tensor.") |
| 226 | + .SetDefault(1.0f); |
| 227 | + AddAttr<int>("seed", |
| 228 | + "(int, default 0) " |
| 229 | + "Random seed of generator." |
| 230 | + "0 means use system wide seed." |
| 231 | + "Note that if seed is not 0, this operator will always " |
| 232 | + "generate the same random numbers every time.") |
| 233 | + .SetDefault(0); |
| 234 | + AddAttr<int>("dtype", |
| 235 | + "(int, default 5(FP32)) " |
| 236 | + "Output data type.") |
| 237 | + .SetDefault(framework::proto::VarType::FP32); |
| 238 | + AddComment(R"DOC( |
| 239 | +TruncatedGaussianRandom Operator. |
| 240 | +
|
| 241 | +Used to initialize tensors with truncated gaussian random generator. |
| 242 | +
|
| 243 | +)DOC"); |
| 244 | + } |
| 245 | +}; |
| 246 | + |
| 247 | +} // namespace operators |
| 248 | +} // namespace paddle |
| 249 | + |
| 250 | +namespace ops = paddle::operators; |
| 251 | +REGISTER_OP_WITHOUT_GRADIENT(truncated_gaussian_random, |
| 252 | + ops::TruncatedGaussianRandomOp, |
| 253 | + ops::TruncatedGaussianRandomOpMaker); |
| 254 | +REGISTER_OP_CPU_KERNEL(truncated_gaussian_random, |
| 255 | + ops::CPUTruncatedGaussianRandomKernel<float>); |
0 commit comments