@@ -650,20 +650,24 @@ <h1>Source code for compressai.entropy_models.entropy_models</h1><div class="hig
650650 < span class ="n "> scale</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> init_scale</ span > < span class ="o "> **</ span > < span class ="p "> (</ span > < span class ="mi "> 1</ span > < span class ="o "> /</ span > < span class ="p "> (</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> filters</ span > < span class ="p "> )</ span > < span class ="o "> +</ span > < span class ="mi "> 1</ span > < span class ="p "> ))</ span >
651651 < span class ="n "> channels</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> channels</ span >
652652
653+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> matrices</ span > < span class ="o "> =</ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> ParameterList</ span > < span class ="p "> ()</ span >
654+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> biases</ span > < span class ="o "> =</ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> ParameterList</ span > < span class ="p "> ()</ span >
655+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> factors</ span > < span class ="o "> =</ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> ParameterList</ span > < span class ="p "> ()</ span >
656+
653657 < span class ="k "> for</ span > < span class ="n "> i</ span > < span class ="ow "> in</ span > < span class ="nb "> range</ span > < span class ="p "> (</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> filters</ span > < span class ="p "> )</ span > < span class ="o "> +</ span > < span class ="mi "> 1</ span > < span class ="p "> ):</ span >
654658 < span class ="n "> init</ span > < span class ="o "> =</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> log</ span > < span class ="p "> (</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> expm1</ span > < span class ="p "> (</ span > < span class ="mi "> 1</ span > < span class ="o "> /</ span > < span class ="n "> scale</ span > < span class ="o "> /</ span > < span class ="n "> filters</ span > < span class ="p "> [</ span > < span class ="n "> i</ span > < span class ="o "> +</ span > < span class ="mi "> 1</ span > < span class ="p "> ]))</ span >
655659 < span class ="n "> matrix</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> (</ span > < span class ="n "> channels</ span > < span class ="p "> ,</ span > < span class ="n "> filters</ span > < span class ="p "> [</ span > < span class ="n "> i</ span > < span class ="o "> +</ span > < span class ="mi "> 1</ span > < span class ="p "> ],</ span > < span class ="n "> filters</ span > < span class ="p "> [</ span > < span class ="n "> i</ span > < span class ="p "> ])</ span >
656660 < span class ="n "> matrix</ span > < span class ="o "> .</ span > < span class ="n "> data</ span > < span class ="o "> .</ span > < span class ="n "> fill_</ span > < span class ="p "> (</ span > < span class ="n "> init</ span > < span class ="p "> )</ span >
657- < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> register_parameter </ span > < span class ="p " > ( </ span > < span class ="sa " > f </ span > < span class =" s2 " > "_matrix </ span > < span class =" si " > { </ span > < span class =" n "> i </ span > < span class ="si " > : </ span > < span class =" s2 " > d </ span > < span class =" si " > } </ span > < span class =" s2 " > " </ span > < span class =" p "> , </ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> Parameter</ span > < span class ="p "> (</ span > < span class ="n "> matrix</ span > < span class ="p "> ))</ span >
661+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> matrices </ span > < span class ="o " > . </ span > < span class ="n "> append </ span > < span class ="p "> ( </ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> Parameter</ span > < span class ="p "> (</ span > < span class ="n "> matrix</ span > < span class ="p "> ))</ span >
658662
659663 < span class ="n "> bias</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> (</ span > < span class ="n "> channels</ span > < span class ="p "> ,</ span > < span class ="n "> filters</ span > < span class ="p "> [</ span > < span class ="n "> i</ span > < span class ="o "> +</ span > < span class ="mi "> 1</ span > < span class ="p "> ],</ span > < span class ="mi "> 1</ span > < span class ="p "> )</ span >
660664 < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> init</ span > < span class ="o "> .</ span > < span class ="n "> uniform_</ span > < span class ="p "> (</ span > < span class ="n "> bias</ span > < span class ="p "> ,</ span > < span class ="o "> -</ span > < span class ="mf "> 0.5</ span > < span class ="p "> ,</ span > < span class ="mf "> 0.5</ span > < span class ="p "> )</ span >
661- < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> register_parameter </ span > < span class ="p " > ( </ span > < span class ="sa " > f </ span > < span class =" s2 " > "_bias </ span > < span class =" si " > { </ span > < span class =" n "> i </ span > < span class ="si " > : </ span > < span class =" s2 " > d </ span > < span class =" si " > } </ span > < span class =" s2 " > " </ span > < span class =" p "> , </ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> Parameter</ span > < span class ="p "> (</ span > < span class ="n "> bias</ span > < span class ="p "> ))</ span >
665+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> biases </ span > < span class ="o " > . </ span > < span class ="n "> append </ span > < span class ="p "> ( </ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> Parameter</ span > < span class ="p "> (</ span > < span class ="n "> bias</ span > < span class ="p "> ))</ span >
662666
663667 < span class ="k "> if</ span > < span class ="n "> i</ span > < span class ="o "> <</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> filters</ span > < span class ="p "> ):</ span >
664668 < span class ="n "> factor</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> (</ span > < span class ="n "> channels</ span > < span class ="p "> ,</ span > < span class ="n "> filters</ span > < span class ="p "> [</ span > < span class ="n "> i</ span > < span class ="o "> +</ span > < span class ="mi "> 1</ span > < span class ="p "> ],</ span > < span class ="mi "> 1</ span > < span class ="p "> )</ span >
665669 < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> init</ span > < span class ="o "> .</ span > < span class ="n "> zeros_</ span > < span class ="p "> (</ span > < span class ="n "> factor</ span > < span class ="p "> )</ span >
666- < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> register_parameter </ span > < span class ="p " > ( </ span > < span class ="sa " > f </ span > < span class =" s2 " > "_factor </ span > < span class =" si " > { </ span > < span class =" n "> i </ span > < span class ="si " > : </ span > < span class =" s2 " > d </ span > < span class =" si " > } </ span > < span class =" s2 " > " </ span > < span class =" p "> , </ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> Parameter</ span > < span class ="p "> (</ span > < span class ="n "> factor</ span > < span class ="p "> ))</ span >
670+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> factors </ span > < span class ="o " > . </ span > < span class ="n "> append </ span > < span class ="p "> ( </ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> Parameter</ span > < span class ="p "> (</ span > < span class ="n "> factor</ span > < span class ="p "> ))</ span >
667671
668672 < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> quantiles</ span > < span class ="o "> =</ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> Parameter</ span > < span class ="p "> (</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> (</ span > < span class ="n "> channels</ span > < span class ="p "> ,</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="mi "> 3</ span > < span class ="p "> ))</ span >
669673 < span class ="n "> init</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ([</ span > < span class ="o "> -</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> init_scale</ span > < span class ="p "> ,</ span > < span class ="mi "> 0</ span > < span class ="p "> ,</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> init_scale</ span > < span class ="p "> ])</ span >
@@ -723,24 +727,23 @@ <h1>Source code for compressai.entropy_models.entropy_models</h1><div class="hig
723727 < span class ="c1 "> # TorchScript not yet working (nn.Mmodule indexing not supported)</ span >
724728 < span class ="n "> logits</ span > < span class ="o "> =</ span > < span class ="n "> inputs</ span >
725729 < span class ="k "> for</ span > < span class ="n "> i</ span > < span class ="ow "> in</ span > < span class ="nb "> range</ span > < span class ="p "> (</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> filters</ span > < span class ="p "> )</ span > < span class ="o "> +</ span > < span class ="mi "> 1</ span > < span class ="p "> ):</ span >
726- < span class ="n "> matrix</ span > < span class ="o "> =</ span > < span class ="nb " > getattr </ span > < span class =" p " > ( </ span > < span class =" bp "> self</ span > < span class ="p " > , </ span > < span class ="sa " > f </ span > < span class ="s2 " > "_matrix </ span > < span class ="si " > { </ span > < span class =" n "> i</ span > < span class ="si " > : </ span > < span class =" s2 " > d </ span > < span class =" si " > } </ span > < span class =" s2 " > " </ span > < span class =" p "> ) </ span >
730+ < span class ="n "> matrix</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o " > . </ span > < span class ="n " > matrices </ span > < span class ="p " > [ </ span > < span class ="n "> i</ span > < span class ="p "> ] </ span >
727731 < span class ="k "> if</ span > < span class ="n "> stop_gradient</ span > < span class ="p "> :</ span >
728732 < span class ="n "> matrix</ span > < span class ="o "> =</ span > < span class ="n "> matrix</ span > < span class ="o "> .</ span > < span class ="n "> detach</ span > < span class ="p "> ()</ span >
729733 < span class ="n "> logits</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> matmul</ span > < span class ="p "> (</ span > < span class ="n "> F</ span > < span class ="o "> .</ span > < span class ="n "> softplus</ span > < span class ="p "> (</ span > < span class ="n "> matrix</ span > < span class ="p "> ),</ span > < span class ="n "> logits</ span > < span class ="p "> )</ span >
730734
731- < span class ="n "> bias</ span > < span class ="o "> =</ span > < span class ="nb " > getattr </ span > < span class =" p " > ( </ span > < span class =" bp "> self</ span > < span class ="p " > , </ span > < span class ="sa " > f </ span > < span class ="s2 " > "_bias </ span > < span class ="si " > { </ span > < span class =" n "> i</ span > < span class ="si " > : </ span > < span class =" s2 " > d </ span > < span class =" si " > } </ span > < span class =" s2 " > " </ span > < span class =" p "> ) </ span >
735+ < span class ="n "> bias</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o " > . </ span > < span class ="n " > biases </ span > < span class ="p " > [ </ span > < span class ="n "> i</ span > < span class ="p "> ] </ span >
732736 < span class ="k "> if</ span > < span class ="n "> stop_gradient</ span > < span class ="p "> :</ span >
733737 < span class ="n "> bias</ span > < span class ="o "> =</ span > < span class ="n "> bias</ span > < span class ="o "> .</ span > < span class ="n "> detach</ span > < span class ="p "> ()</ span >
734- < span class ="n "> logits</ span > < span class ="o "> += </ span > < span class ="n "> bias</ span >
738+ < span class ="n "> logits</ span > < span class ="o "> = </ span > < span class =" n " > logits </ span > < span class =" o " > + </ span > < span class ="n "> bias</ span >
735739
736740 < span class ="k "> if</ span > < span class ="n "> i</ span > < span class ="o "> <</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> filters</ span > < span class ="p "> ):</ span >
737- < span class ="n "> factor</ span > < span class ="o "> =</ span > < span class ="nb " > getattr </ span > < span class =" p " > ( </ span > < span class =" bp "> self</ span > < span class ="p " > , </ span > < span class ="sa " > f </ span > < span class ="s2 " > "_factor </ span > < span class ="si " > { </ span > < span class =" n "> i</ span > < span class ="si " > : </ span > < span class =" s2 " > d </ span > < span class =" si " > } </ span > < span class =" s2 " > " </ span > < span class =" p "> ) </ span >
741+ < span class ="n "> factor</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o " > . </ span > < span class ="n " > factors </ span > < span class ="p " > [ </ span > < span class ="n "> i</ span > < span class ="p "> ] </ span >
738742 < span class ="k "> if</ span > < span class ="n "> stop_gradient</ span > < span class ="p "> :</ span >
739743 < span class ="n "> factor</ span > < span class ="o "> =</ span > < span class ="n "> factor</ span > < span class ="o "> .</ span > < span class ="n "> detach</ span > < span class ="p "> ()</ span >
740- < span class ="n "> logits</ span > < span class ="o "> += </ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> tanh</ span > < span class ="p "> (</ span > < span class ="n "> factor</ span > < span class ="p "> )</ span > < span class ="o "> *</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> tanh</ span > < span class ="p "> (</ span > < span class ="n "> logits</ span > < span class ="p "> )</ span >
744+ < span class ="n "> logits</ span > < span class ="o "> = </ span > < span class =" n " > logits </ span > < span class =" o " > + </ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> tanh</ span > < span class ="p "> (</ span > < span class ="n "> factor</ span > < span class ="p "> )</ span > < span class ="o "> *</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> tanh</ span > < span class ="p "> (</ span > < span class ="n "> logits</ span > < span class ="p "> )</ span >
741745 < span class ="k "> return</ span > < span class ="n "> logits</ span >
742746
743- < span class ="nd "> @torch</ span > < span class ="o "> .</ span > < span class ="n "> jit</ span > < span class ="o "> .</ span > < span class ="n "> unused</ span >
744747 < span class ="k "> def</ span > < span class ="nf "> _likelihood</ span > < span class ="p "> (</ span >
745748 < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> inputs</ span > < span class ="p "> :</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span > < span class ="n "> stop_gradient</ span > < span class ="p "> :</ span > < span class ="nb "> bool</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span >
746749 < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="n "> Tuple</ span > < span class ="p "> [</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span > < span class ="n "> Tensor</ span > < span class ="p "> ]:</ span >
@@ -758,10 +761,13 @@ <h1>Source code for compressai.entropy_models.entropy_models</h1><div class="hig
758761
759762 < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> jit</ span > < span class ="o "> .</ span > < span class ="n "> is_scripting</ span > < span class ="p "> ():</ span >
760763 < span class ="c1 "> # x from B x C x ... to C x B x ...</ span >
761- < span class ="n "> perm</ span > < span class ="o "> =</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> arange</ span > < span class ="p "> (</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> shape</ span > < span class ="p "> ))</ span >
762- < span class ="n "> perm</ span > < span class ="p "> [</ span > < span class ="mi "> 0</ span > < span class ="p "> ],</ span > < span class ="n "> perm</ span > < span class ="p "> [</ span > < span class ="mi "> 1</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="n "> perm</ span > < span class ="p "> [</ span > < span class ="mi "> 1</ span > < span class ="p "> ],</ span > < span class ="n "> perm</ span > < span class ="p "> [</ span > < span class ="mi "> 0</ span > < span class ="p "> ]</ span >
763- < span class ="c1 "> # Compute inverse permutation</ span >
764- < span class ="n "> inv_perm</ span > < span class ="o "> =</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> arange</ span > < span class ="p "> (</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> shape</ span > < span class ="p "> ))[</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> argsort</ span > < span class ="p "> (</ span > < span class ="n "> perm</ span > < span class ="p "> )]</ span >
764+ < span class ="n "> perm</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> cat</ span > < span class ="p "> (</ span >
765+ < span class ="p "> (</ span >
766+ < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> tensor</ span > < span class ="p "> ([</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="mi "> 0</ span > < span class ="p "> ],</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> long</ span > < span class ="p "> ,</ span > < span class ="n "> device</ span > < span class ="o "> =</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> ),</ span >
767+ < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> arange</ span > < span class ="p "> (</ span > < span class ="mi "> 2</ span > < span class ="p "> ,</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> ndim</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> long</ span > < span class ="p "> ,</ span > < span class ="n "> device</ span > < span class ="o "> =</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> ),</ span >
768+ < span class ="p "> )</ span >
769+ < span class ="p "> )</ span >
770+ < span class ="n "> inv_perm</ span > < span class ="o "> =</ span > < span class ="n "> perm</ span >
765771 < span class ="k "> else</ span > < span class ="p "> :</ span >
766772 < span class ="k "> raise</ span > < span class ="ne "> NotImplementedError</ span > < span class ="p "> ()</ span >
767773 < span class ="c1 "> # TorchScript in 2D for static inference</ span >
0 commit comments