Skip to content

Commit 814f163

Browse files
isanghaosshlyapn
andauthored
[GPU] Code refactoring for cloning fc primitive (#30391)
### Details: - Previously, we had member-by-member copy for bias fusion, which is error-prone - Added clone() method to copy FC primitive - Removed const from primitive id because it can be changed after creation --------- Co-authored-by: Sergey Shlyapnikov <[email protected]>
1 parent ff69f1b commit 814f163

File tree

3 files changed

+14
-29
lines changed

3 files changed

+14
-29
lines changed

src/plugins/intel_gpu/include/intel_gpu/primitives/primitive.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ struct primitive {
208208
const primitive_type_id type;
209209

210210
/// @brief Primitive's id.
211-
const primitive_id id;
211+
primitive_id id;
212212

213213
/// @brief Name of original ov operation.
214214
std::string origin_op_name;
@@ -259,7 +259,7 @@ struct primitive {
259259
std::string type_str;
260260
ib >> type_str;
261261
*const_cast<primitive_type_id*>(&type) = prim_map_storage::instance().get_type_id(type_str);
262-
ib >> *const_cast<primitive_id*>(&id);
262+
ib >> id;
263263
ib >> origin_op_name;
264264
ib >> origin_op_type_name;
265265
ib >> output_paddings;
@@ -306,6 +306,9 @@ struct primitive {
306306
/// @brief base class for all primitives implementations.
307307
template <class PType>
308308
class primitive_base : public primitive {
309+
public:
310+
std::shared_ptr<PType> clone() const { return std::make_shared<PType>(static_cast<const PType &>(*this)); }
311+
309312
protected:
310313
explicit primitive_base(const primitive_id& id,
311314
const std::vector<input_info>& input,

src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_primitive_fusing.cpp

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -462,25 +462,11 @@ void prepare_primitive_fusing::fuse_bias(program &p) {
462462
continue;
463463
}
464464

465-
auto fc_with_bias_prim = std::make_shared<fully_connected>(desc->id + "_tmp",
466-
desc->input[0],
467-
desc->weights,
468-
bias_name,
469-
fc.get_output_layout().data_type,
470-
desc->input_size,
471-
desc->weights_rank);
472-
473-
if (desc->compressed_weights) {
474-
fc_with_bias_prim->compressed_weights = true;
475-
fc_with_bias_prim->decompression_scale = desc->decompression_scale;
476-
fc_with_bias_prim->decompression_zero_point = desc->decompression_zero_point;
477-
if (desc->decompression_zero_point_scalar.has_value())
478-
fc_with_bias_prim->decompression_zero_point_scalar = desc->decompression_zero_point_scalar.value();
479-
fc_with_bias_prim->activation_scale = desc->activation_scale;
480-
fc_with_bias_prim->activation_zero_point = desc->activation_zero_point;
481-
fc_with_bias_prim->dynamic_quantized_activation = desc->dynamic_quantized_activation;
482-
fc_with_bias_prim->dynamic_quantized_activation_zp = desc->dynamic_quantized_activation_zp;
483-
}
465+
auto fc_with_bias_prim = desc->clone();
466+
fc_with_bias_prim->id = desc->id + "_tmp";
467+
fc_with_bias_prim->bias = bias_name;
468+
fc_with_bias_prim->output_data_types = {optional_data_type{fc.get_output_layout().data_type}};
469+
484470
auto& new_fc_node = p.get_or_create(fc_with_bias_prim);
485471
fuse_bias_f(fc, new_fc_node, bias_node, eltw_node);
486472
}

src/plugins/intel_gpu/src/graph/program.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -949,16 +949,12 @@ void program::rename(program_node& node, primitive_id const& new_id) {
949949
nodes_map.emplace(new_id, node_ptr);
950950
nodes_map.erase(node.id());
951951

952-
const_cast<primitive_id&>(node.desc->id) = new_id;
952+
node.desc->id = new_id;
953953
}
954954

955955
void program::swap_names(program_node& node1, program_node& node2) {
956-
const auto _extract_id = [](program_node& node) -> primitive_id& {
957-
return const_cast<primitive_id&>(node.desc->id);
958-
};
959-
960956
nodes_map.at(node1.id()).swap(nodes_map.at(node2.id()));
961-
std::swap(_extract_id(node1), _extract_id(node2));
957+
std::swap(node1.desc->id, node2.desc->id);
962958
}
963959

964960
void program::replace_all_usages(program_node& old_node, program_node& new_node, bool remove_if_dangling) {
@@ -1026,8 +1022,8 @@ void program::replace(program_node& old_node, program_node& new_node) {
10261022
new_node.constant = old_node.constant;
10271023
new_node.data_flow = old_node.data_flow;
10281024
new_node.user_mark = old_node.user_mark;
1029-
const_cast<std::string&>(new_node.desc->origin_op_name) = old_node.desc->origin_op_name;
1030-
const_cast<std::string&>(new_node.desc->origin_op_type_name) = old_node.desc->origin_op_type_name;
1025+
new_node.desc->origin_op_name = old_node.desc->origin_op_name;
1026+
new_node.desc->origin_op_type_name = old_node.desc->origin_op_type_name;
10311027

10321028
processing_order.insert(&old_node, &new_node);
10331029
if (processing_order.get_processing_iterator(old_node) != processing_order.end())

0 commit comments

Comments
 (0)