Skip to content

Commit b3f5876

Browse files
authored
Merge pull request #14 from PaddlePaddle/develop
merge to local
2 parents 86511f5 + f906179 commit b3f5876

39 files changed

+2910
-601
lines changed

paddle/fluid/API.spec

Lines changed: 15 additions & 17 deletions
Large diffs are not rendered by default.

paddle/fluid/operators/range_op.cc

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/range_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
class RangeOp : public framework::OperatorWithKernel {
21+
public:
22+
using framework::OperatorWithKernel::OperatorWithKernel;
23+
24+
void InferShape(framework::InferShapeContext *ctx) const override {
25+
if (ctx->HasInput("Start")) {
26+
auto s_dims = ctx->GetInputDim("Start");
27+
PADDLE_ENFORCE((s_dims.size() == 1) && (s_dims[0] == 1),
28+
"The shape of Input(Start) should be [1].");
29+
}
30+
if (ctx->HasInput("End")) {
31+
auto e_dims = ctx->GetInputDim("End");
32+
PADDLE_ENFORCE((e_dims.size() == 1) && (e_dims[0] == 1),
33+
"The shape of Input(End) should be [1].");
34+
}
35+
if (ctx->HasInput("Step")) {
36+
auto step_dims = ctx->GetInputDim("Step");
37+
PADDLE_ENFORCE((step_dims.size() == 1) && (step_dims[0] == 1),
38+
"The shape of Input(Step) should be [1].");
39+
}
40+
ctx->SetOutputDim("Out", {-1});
41+
}
42+
};
43+
44+
class RangeOpMaker : public framework::OpProtoAndCheckerMaker {
45+
public:
46+
void Make() override {
47+
AddInput("Start",
48+
"Start of interval. The interval includes this value. It is a "
49+
"tensor with shape=[1].");
50+
AddInput("End",
51+
"End of interval. The interval does not include this value, "
52+
"except in some cases where step is not an integer and floating "
53+
"point round-off affects the length of out. It is a tensor with "
54+
"shape=[1].");
55+
AddInput("Step", "Spacing between values. It is a tensor with shape=[1].");
56+
AddOutput("Out", "A sequence of numbers.");
57+
AddComment(R"DOC(
58+
Return evenly spaced values within a given interval. Values are generated within the half-open interval [start, stop) (in other words, the interval including start but excluding stop). Like arange function of numpy.
59+
)DOC");
60+
}
61+
};
62+
} // namespace operators
63+
} // namespace paddle
64+
65+
namespace ops = paddle::operators;
66+
REGISTER_OP_WITHOUT_GRADIENT(range, ops::RangeOp, ops::RangeOpMaker);
67+
REGISTER_OP_CPU_KERNEL(range, ops::CPURangeKernel<int>,
68+
ops::CPURangeKernel<float>, ops::CPURangeKernel<double>,
69+
ops::CPURangeKernel<int64_t>);

paddle/fluid/operators/range_op.cu

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/op_registry.h"
16+
#include "paddle/fluid/operators/range_op.h"
17+
#include "paddle/fluid/platform/cuda_primitives.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
#define CUDA_1D_KERNEL_LOOP(i, n) \
23+
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
24+
i += blockDim.x * gridDim.x)
25+
26+
template <typename T>
27+
__global__ void RangeKernel(T start, T step, int64_t size, T* out) {
28+
CUDA_1D_KERNEL_LOOP(index, size) { out[index] = start + step * index; }
29+
}
30+
31+
template <typename T>
32+
class CUDARangeKernel : public framework::OpKernel<T> {
33+
public:
34+
void Compute(const framework::ExecutionContext& context) const override {
35+
auto* start_t = context.Input<framework::Tensor>("Start");
36+
auto* end_t = context.Input<framework::Tensor>("End");
37+
auto* step_t = context.Input<framework::Tensor>("Step");
38+
auto* out = context.Output<framework::Tensor>("Out");
39+
40+
framework::Tensor n;
41+
framework::TensorCopy(*start_t, platform::CPUPlace(), &n);
42+
T start = n.data<T>()[0];
43+
framework::TensorCopy(*end_t, platform::CPUPlace(), &n);
44+
T end = n.data<T>()[0];
45+
framework::TensorCopy(*step_t, platform::CPUPlace(), &n);
46+
T step = n.data<T>()[0];
47+
48+
int64_t size = 0;
49+
GetSize(start, end, step, &size);
50+
out->Resize(framework::make_ddim({size}));
51+
T* out_data = out->mutable_data<T>(context.GetPlace());
52+
53+
auto stream = context.cuda_device_context().stream();
54+
int block = 512;
55+
int grid = (size + block - 1) / block;
56+
RangeKernel<T><<<grid, block, 0, stream>>>(start, step, size, out_data);
57+
}
58+
};
59+
60+
} // namespace operators
61+
} // namespace paddle
62+
63+
namespace ops = paddle::operators;
64+
REGISTER_OP_CUDA_KERNEL(range, ops::CUDARangeKernel<int>,
65+
ops::CUDARangeKernel<int64_t>,
66+
ops::CUDARangeKernel<float>,
67+
ops::CUDARangeKernel<double>);

paddle/fluid/operators/range_op.h

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <functional>
17+
#include "paddle/fluid/framework/op_registry.h"
18+
#include "paddle/fluid/operators/math/math_function.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
23+
template <typename T>
24+
void GetSize(T start, T end, T step, int64_t* size) {
25+
PADDLE_ENFORCE(!std::equal_to<T>()(step, 0),
26+
"The step of range op should not be 0.");
27+
PADDLE_ENFORCE(((start < end) && (step > 0)) || ((start > end) && (step < 0)),
28+
"The step should be greater than 0 while start < end. And the "
29+
"step should be less than 0 while start > end.");
30+
*size = std::is_integral<T>::value
31+
? ((std::abs(end - start) + std::abs(step) - 1) / std::abs(step))
32+
: std::ceil(std::abs((end - start) / step));
33+
}
34+
35+
template <typename T>
36+
class CPURangeKernel : public framework::OpKernel<T> {
37+
public:
38+
void Compute(const framework::ExecutionContext& context) const override {
39+
T start = context.Input<framework::Tensor>("Start")->data<T>()[0];
40+
T end = context.Input<framework::Tensor>("End")->data<T>()[0];
41+
T step = context.Input<framework::Tensor>("Step")->data<T>()[0];
42+
auto* out = context.Output<framework::Tensor>("Out");
43+
int64_t size = 0;
44+
GetSize(start, end, step, &size);
45+
out->Resize(framework::make_ddim({size}));
46+
T* out_data = out->mutable_data<T>(context.GetPlace());
47+
T value = start;
48+
for (int64_t i = 0; i < size; ++i) {
49+
out_data[i] = value;
50+
value += step;
51+
}
52+
}
53+
};
54+
55+
} // namespace operators
56+
} // namespace paddle

python/paddle/fluid/contrib/slim/__init__.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,4 @@
1313
# limitations under the License.
1414

1515
from .core import *
16-
from .graph import *
17-
from .prune import *
18-
__all__ = [
19-
'build_compressor',
20-
'CompressPass',
21-
'ImitationGraph',
22-
'SensitivePruneStrategy',
23-
'MagnitudePruner',
24-
'RatioPruner',
25-
]
16+
__all__ = ['Compressor', ]

python/paddle/fluid/contrib/slim/core/__init__.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,9 @@
1414

1515
from . import config
1616
from .config import *
17-
from . import compress_pass
18-
from .compress_pass import *
17+
from . import compressor
18+
from .compressor import *
1919
from . import strategy
2020
from .strategy import *
21-
from . import pass_builder
22-
from .pass_builder import *
2321

24-
__all__ = config.__all__ + compress_pass.__all__ + strategy.__all__ + pass_builder.__all__
22+
__all__ = config.__all__ + compressor.__all__ + strategy.__all__

python/paddle/fluid/contrib/slim/core/compress_pass.py

Lines changed: 0 additions & 129 deletions
This file was deleted.

0 commit comments

Comments
 (0)