-
Notifications
You must be signed in to change notification settings - Fork 241
Expand file tree
/
Copy pathHipKernelUtils.cpp
More file actions
115 lines (105 loc) · 4.27 KB
/
HipKernelUtils.cpp
File metadata and controls
115 lines (105 loc) · 4.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "HipKernelUtils.hpp"
#include <hipdnn_plugin_sdk/PluginException.hpp>
namespace hip_kernel_provider::hip_kernel_utils
{
// Converts the pointwise operation described by `attrs` into an ActivationParams value.
//
// The forward and backward variants of an operation produce the same result. For the ReLU
// family the optional clip/slope attributes are consulted in this order:
//   1. upper clip and lower clip both present -> CLAMP(lower, upper)
//   2. upper clip present without a lower clip -> CLIPPED_RELU(upper)
//   3. lower clip slope present                -> LEAKY_RELU(slope); the lower clip is not
//                                                 inspected on this path
//   4. none of the above                       -> RELU, unless a non-zero lower clip was given
//
// Throws hipdnn_plugin_sdk::HipdnnPluginException (HIPDNN_PLUGIN_STATUS_BAD_PARAM) when the
// operation is not handled, when a plain ReLU carries a non-zero lower clip, or when a softplus
// beta other than 1.0 is supplied.
ActivationParams parseActivation(const hipdnn_data_sdk::data_objects::PointwiseAttributes& attrs)
{
    using Mode = hipdnn_data_sdk::data_objects::PointwiseMode;

    // Every result leaves the last coefficient at zero, so only the mode and the first two
    // coefficients vary between cases.
    const auto params = [](ActivationMode mode, double first = 0.0, double second = 0.0) {
        return ActivationParams{mode, first, second, 0.0};
    };

    switch(attrs.operation())
    {
    case Mode::RELU_FWD:
    case Mode::RELU_BWD:
    {
        const auto lowerClip = attrs.relu_lower_clip();
        const auto upperClip = attrs.relu_upper_clip();

        if(upperClip)
        {
            // An upper clip selects a bounded variant; the lower clip decides which one.
            if(lowerClip)
            {
                return params(ActivationMode::CLAMP,
                              static_cast<double>(*lowerClip),
                              static_cast<double>(*upperClip));
            }
            return params(ActivationMode::CLIPPED_RELU, static_cast<double>(*upperClip));
        }

        const auto negativeSlope = attrs.relu_lower_clip_slope();
        if(negativeSlope)
        {
            return params(ActivationMode::LEAKY_RELU, static_cast<double>(*negativeSlope));
        }

        // Only a missing or zero lower clip is accepted for the plain ReLU.
        if(lowerClip.has_value() && lowerClip.value() != 0.f)
        {
            throw hipdnn_plugin_sdk::HipdnnPluginException(
                HIPDNN_PLUGIN_STATUS_BAD_PARAM,
                "Standard relu with a non-zero lower_clip is not supported");
        }
        return params(ActivationMode::RELU);
    }

    case Mode::SIGMOID_FWD:
    case Mode::SIGMOID_BWD:
        return params(ActivationMode::LOGISTIC);

    case Mode::TANH_FWD:
    case Mode::TANH_BWD:
        // Tanh is emitted with both leading coefficients set to one.
        return params(ActivationMode::TANH, 1.0, 1.0);

    case Mode::ELU_FWD:
    case Mode::ELU_BWD:
    {
        // Alpha falls back to 1.0 when the attribute is absent.
        double alpha = 1.0;
        const auto requestedAlpha = attrs.elu_alpha();
        if(requestedAlpha)
        {
            alpha = static_cast<double>(*requestedAlpha);
        }
        return params(ActivationMode::ELU, alpha);
    }

    case Mode::SOFTPLUS_FWD:
    case Mode::SOFTPLUS_BWD:
    {
        // A beta is tolerated only when it is exactly 1.0; an absent beta is accepted as-is.
        const auto beta = attrs.softplus_beta();
        if(beta && static_cast<double>(*beta) != 1.0)
        {
            throw hipdnn_plugin_sdk::HipdnnPluginException(HIPDNN_PLUGIN_STATUS_BAD_PARAM,
                                                           "Softplus only supports beta = 1.0");
        }
        return params(ActivationMode::SOFTRELU);
    }

    case Mode::ABS:
        return params(ActivationMode::ABS);

    case Mode::IDENTITY:
        return params(ActivationMode::PASTHRU);

    default:
        throw hipdnn_plugin_sdk::HipdnnPluginException(HIPDNN_PLUGIN_STATUS_BAD_PARAM,
                                                       "Unsupported activation operation");
    }
}
// Returns a copy of the first entry in `deviceBuffers` whose `uid` field equals `uid`.
//
// Parameters:
//   uid              - identifier to search for.
//   deviceBuffers    - array of `numDeviceBuffers` entries; may be null only when the count is 0.
//   numDeviceBuffers - number of entries available in `deviceBuffers`.
//
// Throws hipdnn_plugin_sdk::HipdnnPluginException with
//   HIPDNN_PLUGIN_STATUS_BAD_PARAM     when `deviceBuffers` is null but the count is non-zero,
//   HIPDNN_PLUGIN_STATUS_INVALID_VALUE when no entry carries the requested uid.
hipdnnPluginDeviceBuffer_t findDeviceBuffer(int64_t uid,
                                            const hipdnnPluginDeviceBuffer_t* deviceBuffers,
                                            uint32_t numDeviceBuffers)
{
    // Indexing a null array would be undefined behaviour, so reject it up front instead of
    // crashing inside the search loop.
    if(deviceBuffers == nullptr && numDeviceBuffers > 0)
    {
        throw hipdnn_plugin_sdk::HipdnnPluginException(
            HIPDNN_PLUGIN_STATUS_BAD_PARAM,
            "Device buffer array is null while " + std::to_string(numDeviceBuffers)
                + " device buffers were expected.");
    }

    for(uint32_t i = 0; i < numDeviceBuffers; ++i)
    {
        if(deviceBuffers[i].uid == uid)
        {
            return deviceBuffers[i];
        }
    }

    throw hipdnn_plugin_sdk::HipdnnPluginException(
        HIPDNN_PLUGIN_STATUS_INVALID_VALUE,
        "Device buffer with the uid: " + std::to_string(uid)
            + " not found in the provided device buffers.");
}
// Resolves `uid` to the tensor attributes registered for it in `tensorMap`.
//
// Returns a reference to the object pointed to by the mapped value; the map itself is not
// modified. Throws hipdnn_plugin_sdk::HipdnnPluginException with
// HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR when `uid` has no entry in the map.
const hipdnn_data_sdk::data_objects::TensorAttributes& findTensorAttributes(
    const std::unordered_map<int64_t, const hipdnn_data_sdk::data_objects::TensorAttributes*>&
        tensorMap,
    int64_t uid)
{
    const auto entry = tensorMap.find(uid);

    // Handle the failure first so the success path is a single return.
    if(entry == tensorMap.end())
    {
        throw hipdnn_plugin_sdk::HipdnnPluginException(
            HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR,
            "Failed to find tensor with UID in tensorMap: " + std::to_string(uid));
    }

    return *entry->second;
}
} // namespace hip_kernel_provider::hip_kernel_utils