Skip to content

Commit 3fe0a29

Browse files
committed
deploy: 721fdee
1 parent 8e9be4c commit 3fe0a29

File tree

2 files changed

+21
-14
lines changed

2 files changed

+21
-14
lines changed

_modules/compressai/entropy_models/entropy_models.html

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -650,20 +650,24 @@ <h1>Source code for compressai.entropy_models.entropy_models</h1><div class="hig
650650
<span class="n">scale</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">init_scale</span> <span class="o">**</span> <span class="p">(</span><span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
651651
<span class="n">channels</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">channels</span>
652652

653+
<span class="bp">self</span><span class="o">.</span><span class="n">matrices</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ParameterList</span><span class="p">()</span>
654+
<span class="bp">self</span><span class="o">.</span><span class="n">biases</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ParameterList</span><span class="p">()</span>
655+
<span class="bp">self</span><span class="o">.</span><span class="n">factors</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ParameterList</span><span class="p">()</span>
656+
653657
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span>
654658
<span class="n">init</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">expm1</span><span class="p">(</span><span class="mi">1</span> <span class="o">/</span> <span class="n">scale</span> <span class="o">/</span> <span class="n">filters</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]))</span>
655659
<span class="n">matrix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="n">filters</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">],</span> <span class="n">filters</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
656660
<span class="n">matrix</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">fill_</span><span class="p">(</span><span class="n">init</span><span class="p">)</span>
657-
<span class="bp">self</span><span class="o">.</span><span class="n">register_parameter</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;_matrix</span><span class="si">{</span><span class="n">i</span><span class="si">:</span><span class="s2">d</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">matrix</span><span class="p">))</span>
661+
<span class="bp">self</span><span class="o">.</span><span class="n">matrices</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">matrix</span><span class="p">))</span>
658662

659663
<span class="n">bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="n">filters</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">],</span> <span class="mi">1</span><span class="p">)</span>
660664
<span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">uniform_</span><span class="p">(</span><span class="n">bias</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">)</span>
661-
<span class="bp">self</span><span class="o">.</span><span class="n">register_parameter</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;_bias</span><span class="si">{</span><span class="n">i</span><span class="si">:</span><span class="s2">d</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">bias</span><span class="p">))</span>
665+
<span class="bp">self</span><span class="o">.</span><span class="n">biases</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">bias</span><span class="p">))</span>
662666

663667
<span class="k">if</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="p">):</span>
664668
<span class="n">factor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="n">filters</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">],</span> <span class="mi">1</span><span class="p">)</span>
665669
<span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">zeros_</span><span class="p">(</span><span class="n">factor</span><span class="p">)</span>
666-
<span class="bp">self</span><span class="o">.</span><span class="n">register_parameter</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;_factor</span><span class="si">{</span><span class="n">i</span><span class="si">:</span><span class="s2">d</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">factor</span><span class="p">))</span>
670+
<span class="bp">self</span><span class="o">.</span><span class="n">factors</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">factor</span><span class="p">))</span>
667671

668672
<span class="bp">self</span><span class="o">.</span><span class="n">quantiles</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
669673
<span class="n">init</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">([</span><span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">init_scale</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">init_scale</span><span class="p">])</span>
@@ -723,24 +727,23 @@ <h1>Source code for compressai.entropy_models.entropy_models</h1><div class="hig
723727
<span class="c1"># TorchScript not yet working (nn.Mmodule indexing not supported)</span>
724728
<span class="n">logits</span> <span class="o">=</span> <span class="n">inputs</span>
725729
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span>
726-
<span class="n">matrix</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;_matrix</span><span class="si">{</span><span class="n">i</span><span class="si">:</span><span class="s2">d</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
730+
<span class="n">matrix</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">matrices</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
727731
<span class="k">if</span> <span class="n">stop_gradient</span><span class="p">:</span>
728732
<span class="n">matrix</span> <span class="o">=</span> <span class="n">matrix</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
729733
<span class="n">logits</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">F</span><span class="o">.</span><span class="n">softplus</span><span class="p">(</span><span class="n">matrix</span><span class="p">),</span> <span class="n">logits</span><span class="p">)</span>
730734

731-
<span class="n">bias</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;_bias</span><span class="si">{</span><span class="n">i</span><span class="si">:</span><span class="s2">d</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
735+
<span class="n">bias</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">biases</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
732736
<span class="k">if</span> <span class="n">stop_gradient</span><span class="p">:</span>
733737
<span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
734-
<span class="n">logits</span> <span class="o">+=</span> <span class="n">bias</span>
738+
<span class="n">logits</span> <span class="o">=</span> <span class="n">logits</span> <span class="o">+</span> <span class="n">bias</span>
735739

736740
<span class="k">if</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="p">):</span>
737-
<span class="n">factor</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;_factor</span><span class="si">{</span><span class="n">i</span><span class="si">:</span><span class="s2">d</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
741+
<span class="n">factor</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">factors</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
738742
<span class="k">if</span> <span class="n">stop_gradient</span><span class="p">:</span>
739743
<span class="n">factor</span> <span class="o">=</span> <span class="n">factor</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
740-
<span class="n">logits</span> <span class="o">+=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">factor</span><span class="p">)</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">logits</span><span class="p">)</span>
744+
<span class="n">logits</span> <span class="o">=</span> <span class="n">logits</span> <span class="o">+</span> <span class="n">torch</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">factor</span><span class="p">)</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">logits</span><span class="p">)</span>
741745
<span class="k">return</span> <span class="n">logits</span>
742746

743-
<span class="nd">@torch</span><span class="o">.</span><span class="n">jit</span><span class="o">.</span><span class="n">unused</span>
744747
<span class="k">def</span> <span class="nf">_likelihood</span><span class="p">(</span>
745748
<span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">stop_gradient</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
746749
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span>
@@ -758,10 +761,13 @@ <h1>Source code for compressai.entropy_models.entropy_models</h1><div class="hig
758761

759762
<span class="k">if</span> <span class="ow">not</span> <span class="n">torch</span><span class="o">.</span><span class="n">jit</span><span class="o">.</span><span class="n">is_scripting</span><span class="p">():</span>
760763
<span class="c1"># x from B x C x ... to C x B x ...</span>
761-
<span class="n">perm</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span>
762-
<span class="n">perm</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">perm</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">perm</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">perm</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
763-
<span class="c1"># Compute inverse permutation</span>
764-
<span class="n">inv_perm</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">))[</span><span class="n">np</span><span class="o">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">perm</span><span class="p">)]</span>
764+
<span class="n">perm</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
765+
<span class="p">(</span>
766+
<span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
767+
<span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
768+
<span class="p">)</span>
769+
<span class="p">)</span>
770+
<span class="n">inv_perm</span> <span class="o">=</span> <span class="n">perm</span>
765771
<span class="k">else</span><span class="p">:</span>
766772
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">()</span>
767773
<span class="c1"># TorchScript in 2D for static inference</span>

_modules/compressai/models/base.html

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ <h1>Source code for compressai.models.base</h1><div class="highlight"><pre>
327327

328328
<span class="kn">from</span> <span class="nn">compressai.entropy_models</span> <span class="kn">import</span> <span class="n">EntropyBottleneck</span><span class="p">,</span> <span class="n">GaussianConditional</span>
329329
<span class="kn">from</span> <span class="nn">compressai.latent_codecs</span> <span class="kn">import</span> <span class="n">LatentCodec</span>
330-
<span class="kn">from</span> <span class="nn">compressai.models.utils</span> <span class="kn">import</span> <span class="n">update_registered_buffers</span>
330+
<span class="kn">from</span> <span class="nn">compressai.models.utils</span> <span class="kn">import</span> <span class="n">remap_old_keys</span><span class="p">,</span> <span class="n">update_registered_buffers</span>
331331

332332
<span class="n">__all__</span> <span class="o">=</span> <span class="p">[</span>
333333
<span class="s2">&quot;CompressionModel&quot;</span><span class="p">,</span>
@@ -395,6 +395,7 @@ <h1>Source code for compressai.models.base</h1><div class="highlight"><pre>
395395
<span class="p">[</span><span class="s2">&quot;_quantized_cdf&quot;</span><span class="p">,</span> <span class="s2">&quot;_offset&quot;</span><span class="p">,</span> <span class="s2">&quot;_cdf_length&quot;</span><span class="p">],</span>
396396
<span class="n">state_dict</span><span class="p">,</span>
397397
<span class="p">)</span>
398+
<span class="n">state_dict</span> <span class="o">=</span> <span class="n">remap_old_keys</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">state_dict</span><span class="p">)</span>
398399

399400
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">module</span><span class="p">,</span> <span class="n">GaussianConditional</span><span class="p">):</span>
400401
<span class="n">update_registered_buffers</span><span class="p">(</span>

0 commit comments

Comments
 (0)