44"""
55
66import time # for hash generation
7- from abc import abstractmethod
87from enum import Enum
98from typing import Any , Iterable , List
109
1110import torch
1211
12+ from .common import EXPORT_OVERLAPPED_CONFIG
1313from .storage import Serializable
1414
1515
16+ class QuantizationVisiblity (Enum ):
17+ FORCE_EXPORT = 1
18+ EXPOET_WHEN_ACTIVE = 2
19+ INTERNAL = 3
20+
1621class NetworkFramework (Enum ):
1722 PPL = 1
1823 ONNX = 2
@@ -365,7 +370,7 @@ def __init__(
365370 offset : Any = None ,
366371 observer_algorithm : str = None ,
367372 detail : Any = None ,
368- require_export : bool = None ,
373+ visiblity : QuantizationVisiblity = QuantizationVisiblity . EXPOET_WHEN_ACTIVE ,
369374 state : QuantizationStates = QuantizationStates .INITIAL
370375 ):
371376 """Create a PPQ Tensor Quantization Configuration Instance.
@@ -395,7 +400,13 @@ def __init__(
395400 detail (Any, optional): Only used by PPQ internal logic, detail is used to store some internal data,
396401 you are not supposed to use it.
397402
398- require_export (bool, optional): If require_export == True, PPQ exporter will export this TQC ignoring state checks.
403+ visiblity (Visiblity): visiblity is the attribute that controls export logic.
404+
405+ Currently, there are 3 Visiblity level in PPQ:
406+ if Visiblity == FORCE_EXPORT, ppq exporter will export this TQC
407+ ignoring state check(even if current TQC has been overrlapped).
408+ if Visiblity == EXPORT_WHEN_ACTIVD, ppq exporter will export this TQC only when it has been actived.
409+ if Visiblity == INTERNAL, This TQC will not be exported.
399410
400411 state (QuantizationStates, optional):
401412 Defaults to QuantizationStates.INITIAL, see QuantizationStates for more detail.
@@ -416,17 +427,25 @@ def __init__(
416427 self .detail = {} if detail is None else detail
417428 self ._father_config = self # union-find
418429 self ._hash = self .__create_hash ()
419- self ._require_export = require_export
430+ self ._visiblity = visiblity
420431 super ().__init__ ()
421432
422- @ abstractmethod
423- def export (self ) -> str :
424- raise Exception ('Implement this first' )
433+ def can_export (self ) -> bool :
434+ if self .visiblity == QuantizationVisiblity .INTERNAL : return False
435+ type_check = isinstance (self .scale , torch .Tensor ) and isinstance (self .offset , torch .Tensor )
436+ valid_states = {QuantizationStates .BAKED , QuantizationStates .PASSIVE_BAKED }
437+
438+ if EXPORT_OVERLAPPED_CONFIG : valid_states .add (QuantizationStates .OVERLAPPED )
439+ state_check = QuantizationStates .is_activated (self .state ) or self .state in valid_states
440+
441+ if (state_check or self .visiblity == QuantizationVisiblity .FORCE_EXPORT ):
442+ if type_check : return True
443+ return False
425444
426445 def __eq__ (self , o : object ) -> bool :
427446 if not isinstance (o , TensorQuantizationConfig ):
428- raise TypeError ('Can only compare TensorQuantizationConfig object ' \
429- 'with another TensorQuantizationConfig object.' )
447+ raise TypeError ('Can only compare TensorQuantizationConfig object '
448+ 'with another TensorQuantizationConfig object.' )
430449 return self ._hash == o ._hash
431450
432451 def __str__ (self ) -> str :
@@ -509,17 +528,13 @@ def is_revisable(self):
509528 })
510529
511530 @ property
512- def exportable (self ) -> bool :
513- value_check = isinstance (self .scale , torch .Tensor )
514- if self ._require_export is None :
515- state_check = QuantizationStates .can_export (self .state )
516- return (value_check and state_check )
517- else : return (self ._require_export and value_check )
518-
519- @ exportable .setter
520- def exportable (self , export_override : bool ):
521- self ._require_export = export_override
522-
531+ def visiblity (self ) -> bool :
532+ return self ._visiblity
533+
534+ @ visiblity .setter
535+ def visiblity (self , visiblity : bool ):
536+ self ._visiblity = visiblity
537+
523538 @ property
524539 def scale (self ) -> torch .Tensor :
525540 if self .dominated_by == self : return self ._scale
0 commit comments