Refactor export helpers #958
Conversation
@daniil-lyakhov Please define what a closing prunable layer is in the PR description and add more details about why it is needed.
Every prunable layer L should have a closing prunable layer, i.e. the next prunable layer after L whose input channels are being pruned. The green group is a group of prunable layers that share the same channel dimension; we remove some of their output channels. The yellow prunable layer is the closing prunable layer whose input channels we remove accordingly. Hope this helps!
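To make the relationship concrete, here is a minimal, hypothetical PyTorch sketch (not taken from this PR): pruning output channels of `conv1` forces the same channels to be removed from the input dimension of its closing prunable layer `conv2`.

```python
import torch
import torch.nn as nn

# Hypothetical example: conv1 is a prunable layer, conv2 is its closing
# prunable layer -- the next prunable layer whose *input* channels are pruned
# when conv1's *output* channels are pruned.
conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)   # prune outputs here
conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3)  # closing layer: prune inputs here

keep = [0, 1, 2, 3, 5, 6]  # output channels of conv1 that survive pruning (the mask)

# Removing output channels of conv1 ...
conv1.weight.data = conv1.weight.data[keep]
conv1.bias.data = conv1.bias.data[keep]
# ... requires removing the same channels from conv2's input dimension.
conv2.weight.data = conv2.weight.data[:, keep]

x = torch.randn(1, 3, 32, 32)
assert conv2(conv1(x)).shape[1] == 16  # shapes stay consistent after pruning
```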
Force-pushed from ccb0597 to 7c40fc1
if input_node.data.get('output_mask', None) is not None:
    continue

source_nodes = get_sources_of_node(input_node, graph, cls.ConvolutionOp.get_all_op_aliases() +
                                   cls.StopMaskForwardOp.get_all_op_aliases() +
                                   cls.InputOp.get_all_op_aliases())
sources_types = [node.node_type for node in source_nodes] + [input_node.node_type]
Please be aware of two changes here:
if input_node.data.get('output_mask', None) is not None:
instead of
if input_node.data.get('output_mask', None) is None:
and
sources_types = [node.node_type for node in source_nodes] + [input_node.node_type]
instead of
sources_types = [node.node_type for node in source_nodes]
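To make the effect of the first change concrete, here is a toy, self-contained illustration (not NNCF code; the dictionaries are hypothetical) of how flipping `is None` to `is not None` inverts which inputs the `continue` skips:

```python
# Toy illustration: with `is not None`, inputs that already carry an output
# mask are skipped, so only the mask-less inputs are analysed further.
inputs = [
    {'name': 'conv_out', 'output_mask': [1, 0, 1]},  # mask already set
    {'name': 'model_input', 'output_mask': None},    # no mask yet
]

def analysed(nodes, skip_if_mask_present):
    result = []
    for node in nodes:
        has_mask = node.get('output_mask', None) is not None
        if has_mask == skip_if_mask_present:
            continue  # mirrors the `continue` in the diff above
        result.append(node['name'])
    return result

print(analysed(inputs, skip_if_mask_present=False))  # old behaviour: ['conv_out']
print(analysed(inputs, skip_if_mask_present=True))   # new behaviour: ['model_input']
```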
@daniil-lyakhov Please describe how the second change (to sources_types) will impact how the algorithm works on known models with Concat, and please test these changes carefully on those models.
Tested locally; please check daniil-lyakhov#10
Ran the tests on CI: #969
Force-pushed from ba7140a to e022a1f
Minor TODO:
Jenkins please retry a build
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
Could we remove this module and use the 'common' module in the TensorFlow backend?
if mask is None:
    concat_axis = node.layer_attributes.axis
    concat_dim = input_edges[i].tensor_shape[concat_axis]
    device = cls._get_masks_device(input_masks)
Could you please explain what type the mask is? Can we avoid functions such as _get_masks_device?
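For context, one plausible continuation of this branch (an assumption, not something shown in the diff above) is that an unmasked concat input gets an all-ones mask of size concat_dim on the device of the existing masks, after which all per-input masks are concatenated. A framework-agnostic numpy sketch of that filling step, which sidesteps the device lookup entirely:

```python
import numpy as np

# Hypothetical, framework-agnostic version of the concat mask filling
# (a sketch, not the PR's code): inputs without a mask contribute an
# all-ones mask of their size along the concat axis, so no device
# handling (like _get_masks_device) is needed at all.
def fill_input_masks(input_masks, input_shapes, concat_axis):
    filled = []
    for mask, shape in zip(input_masks, input_shapes):
        if mask is None:
            mask = np.ones(shape[concat_axis], dtype=np.float32)  # keep all channels
        filled.append(mask)
    return np.concatenate(filled)  # resulting mask for the concat output

# Example: the second input has no mask yet
masks = [np.array([1.0, 0.0, 1.0]), None]
shapes = [(1, 3, 8, 8), (1, 2, 8, 8)]
print(fill_input_masks(masks, shapes, concat_axis=1))  # [1. 0. 1. 1. 1.]
```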
Force-pushed from 23a4ff1 to 0712dff
Force-pushed from 0712dff to 9dd421a
Force-pushed from 648b4f2 to 4340499
…ated from non prunable node
Force-pushed from 0a972d2 to 5c46689
* Separated the common parts of the pruning export_helpers.
* Implemented a base realization of the `mask_propogation` methods so that the mask propagation algorithm can be used without any framework-specific (TF, Torch) functions.

The main point of this refactoring is a mask propagation algorithm with framework-agnostic mask support. This will allow us to run the mask propagation algorithm at the model analysis stage to check the invariant "every prunable layer has a corresponding closing prunable layer". It will also let us propagate masks through reshape layers and thereby enable one of the realizations of the SE block in mobilenet_v3. For this, numpy arrays will be propagated by the base realization of `mask_propogation` (see the sketch after the TODO list below).

TODO:
- [x] Implement tests to cover the common and algorithm-specific mask propagation
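As a rough illustration of the framework-agnostic idea (a sketch with hypothetical helper names, not the actual NNCF API): masks can be plain numpy arrays, and a base `mask_propogation`-style pass only needs graph and shape logic, so the same code can serve the analysis stage for both backends.

```python
import numpy as np

# Minimal sketch (hypothetical helpers, not NNCF's real classes) of propagating
# framework-agnostic numpy masks through a tiny graph: conv -> reshape -> conv.
def propagate_identity(node, input_masks):
    # Ops like reshape just forward their single input mask unchanged.
    return input_masks[0]

def propagate_conv(node, input_masks):
    # A pruned convolution produces its own output mask over output channels.
    return node['output_mask']

PROPAGATORS = {'conv': propagate_conv, 'reshape': propagate_identity}

graph = [  # topologically sorted toy graph
    {'name': 'conv1', 'type': 'conv', 'output_mask': np.array([1, 0, 1, 1])},
    {'name': 'reshape', 'type': 'reshape', 'output_mask': None},
    {'name': 'conv2', 'type': 'conv', 'output_mask': np.array([1, 1, 0])},
]

prev_mask = None
for node in graph:
    node['propagated_mask'] = PROPAGATORS[node['type']](node, [prev_mask])
    prev_mask = node['propagated_mask']

# The mask reaching conv2 equals conv1's output mask, so conv2 "closes"
# conv1's pruning group even though a reshape sits between them.
print(graph[1]['propagated_mask'])  # [1 0 1 1] -- the mask survived the reshape
```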