
Commit 5c1a0f6 (merge of parents 64cd5a7 and 18cef8b)

File tree: 1 file changed (+35, -5 lines)


project-5/proj5.html

Lines changed: 35 additions & 5 deletions
@@ -1020,9 +1020,9 @@ <h4>Limitations on pure noise</h4>
To generate plausible-looking digits, we need a different approach than one-step denoising.

<h4>The Flow Matching Model</h4>
Instead of trying to denoise the image in a single step, we aim to iteratively denoise it. To do this, we will start by specifying how intermediate noisy samples are constructed. The simplest approach is linear interpolation: let the intermediate sample be <code>x<sub>t</sub></code> = (1 - <code>t</code>)<code>x<sub>0</sub></code> + <code>t</code><code>x<sub>1</sub></code> for a given <code>t</code> &isin; [0, 1], where <code>x<sub>0</sub></code> is the noise and <code>x<sub>1</sub></code> is the clean image.<br>

<br>Now that we have an equation relating a clean image to any pure noise sample, we can train our model to learn the <strong>flow</strong>, i.e. the change with respect to <code>t</code> for any given <code>x<sub>t</sub></code>. This produces a vector field over images, where the velocity at time <code>t</code> is d/dt <code>x<sub>t</sub></code> = <code>x<sub>1</sub></code> - <code>x<sub>0</sub></code>. Therefore, if we can predict <code>x<sub>1</sub></code> - <code>x<sub>0</sub></code> for any given <code>t</code> and <code>x<sub>t</sub></code>, we can follow the path traced out by the vector field and arrive near the manifold of clean images. This technique is known as a <strong>flow matching model</strong>; with the model trained, we can numerically integrate a random noise sample <code>x<sub>0</sub></code> for a set number of iterations using Euler's method and obtain a clean image <code>x<sub>1</sub></code>.
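As a concrete sketch of this sampling procedure (in NumPy rather than the PyTorch used by the actual UNet, with a hypothetical <code>euler_sample</code> helper): if the model predicted the true velocity <code>x<sub>1</sub></code> - <code>x<sub>0</sub></code> exactly, then <code>T</code> Euler steps of size 1/<code>T</code> would carry the noise exactly to the clean image.

```python
import numpy as np

def euler_sample(u, x0, T=300):
    """Integrate dx/dt = u(x, t) from t = 0 to t = 1 with T Euler steps."""
    x = x0.copy()
    for i in range(T):
        t = i / T
        x = x + (1.0 / T) * u(x, t)   # x_{t + 1/T} = x_t + (1/T) * u(x_t, t)
    return x

# Toy check: with the ideal velocity field x1 - x0 (constant along the
# linear path), Euler integration carries pure noise x0 to x1.
rng = np.random.default_rng(0)
x0 = rng.standard_normal(28 * 28)     # pure noise sample
x1 = rng.standard_normal(28 * 28)     # stand-in "clean image"
u_ideal = lambda x, t: x1 - x0        # what a perfectly trained model predicts
out = euler_sample(u_ideal, x0, T=300)
```

In practice the learned model only approximates this field, so more iterations trace the path more faithfully.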

<h4>Training a Time-Conditioned UNet</h4>
To add time conditioning to our UNet, we will make the following changes to our model architecture:
@@ -1034,9 +1034,14 @@ <h4>Training a Time-Conditioned UNet</h4>
</div>

<h4>Flow Matching Hyperparameters</h4>
For the hyperparameters, we will use a batch size of 64, a learning rate of <code>1e-2</code>, a hidden dimension of 64, the Adam optimizer with the given learning rate, an exponential learning rate decay scheduler with &gamma; = 0.1<sup>(1.0 / <code>num_epochs</code>)</sup>, a sampling iteration count of <code>T</code> = 300, and a training time of 10 epochs. To advance the scheduler, we will call <code>scheduler.step()</code> at the end of each training epoch.
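As a sanity check on these numbers (plain Python mirroring what an exponential decay scheduler does when stepped once per epoch), the chosen &gamma; makes the learning rate fall by exactly one order of magnitude over the full run:

```python
num_epochs = 10
lr = 1e-2                              # initial learning rate
gamma = 0.1 ** (1.0 / num_epochs)      # per-epoch decay factor

lrs = []
for epoch in range(num_epochs):
    lrs.append(lr)                     # lr in effect during this epoch
    lr *= gamma                        # what scheduler.step() does at epoch end

# lrs[0] == 1e-2; after 10 epochs lr has decayed to ~1e-3.
```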

<h4>Embedding <code>t</code> in the UNet</h4>
To embed <code>t</code> in the UNet, we will multiply the <code>unflat</code> and <code>firstUpBlock</code> tensors (the results of applying the <strong>Unflatten</strong> operation and the first <strong>UpConv</strong> operation, respectively) by <code>fc1_t</code> and <code>fc2_t</code>. These are the results of passing <code>t</code> through the first and second FCBlock, where the first produces a tensor with twice the number of hidden dimensions and the second has the same number of hidden dimensions as the first and last ConvBlock (i.e. the two results have 2D and D channels, respectively). In pseudocode:
<pre><code>unflat_cond = unflat * fc1_t
firstUpBlock_cond = firstUpBlock * fc2_t</code></pre>
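To make the pseudocode above concrete, here is a NumPy sketch of the shapes involved (the batch size and the 7&times;7 / 14&times;14 spatial sizes are illustrative assumptions for a 28&times;28 MNIST UNet; D is the hidden dimension): the FCBlock outputs are shaped (N, C, 1, 1) so the multiplication broadcasts over the spatial dimensions.

```python
import numpy as np

N, D = 8, 64                                     # batch size, hidden dim (assumed)
unflat = np.random.randn(N, 2 * D, 7, 7)         # after Unflatten: 2D channels
firstUpBlock = np.random.randn(N, D, 14, 14)     # after first UpConv: D channels
fc1_t = np.random.randn(N, 2 * D, 1, 1)          # first FCBlock(t): 2D channels
fc2_t = np.random.randn(N, D, 1, 1)              # second FCBlock(t): D channels

# Channel-wise modulation; the trailing 1x1 dims broadcast over height and width.
unflat_cond = unflat * fc1_t
firstUpBlock_cond = firstUpBlock * fc2_t
```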

<h4>Time-Conditioned Forward and Sampling Operations</h4>
To train our model, for each clean image <code>x<sub>1</sub></code> we will sample <code>x<sub>0</sub></code> ~ &Nscr;(0, &#119816;) and <code>t</code> ~ U([0, 1]), where U is the uniform distribution. After computing <code>x<sub>t</sub></code> = (1 - <code>t</code>)<code>x<sub>0</sub></code> + <code>t</code><code>x<sub>1</sub></code>, we will feed <code>x<sub>t</sub></code> and <code>t</code> into our UNet and compute the loss between u<sub>&theta;</sub>(<code>x<sub>t</sub></code>, <code>t</code>) and <code>x<sub>1</sub></code> - <code>x<sub>0</sub></code>. Below is the new model's training loss curve:
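A minimal sketch of this training objective (NumPy, with <code>u_theta</code> standing in for the UNet; the function name and the single per-batch scalar <code>t</code> are illustrative assumptions):

```python
import numpy as np

def flow_matching_loss(u_theta, x1, rng):
    """One training step: regress the predicted flow onto x1 - x0."""
    x0 = rng.standard_normal(x1.shape)        # x0 ~ N(0, I)
    t = rng.uniform(0.0, 1.0)                 # t ~ U([0, 1])
    xt = (1.0 - t) * x0 + t * x1              # linear interpolation
    pred = u_theta(xt, t)
    return np.mean((pred - (x1 - x0)) ** 2)   # MSE against the true flow

rng = np.random.default_rng(0)
x1 = rng.standard_normal((4, 28 * 28))        # a batch of "clean images"
loss = flow_matching_loss(lambda xt, t: np.zeros_like(xt), x1, rng)
```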
<div align="center">
<figure>
@@ -1054,9 +1059,34 @@ <h4>Forward and Sampling Operations</h4>
Although the results are not perfect, the improvements from the 1st epoch up to the 10th are already noticeable.

<h4>Adding Class-Conditioning to Time-Conditioned UNet</h4>
To make further improvements to our image generation, we can condition our UNet on the digit classes 0-9. This requires adding an additional FCBlock for each time condition, where the class vector <code>c</code> is a one-hot vector. To ensure that the UNet still works without conditioning on the class (in order to implement CFG later), we will use a dropout rate <code>p<sub>uncond</sub></code> of 0.1, with which we set the one-hot vector <code>c</code> to all 0s.

<h4>Embedding <code>c</code> and <code>t</code> in the UNet</h4>
To embed <code>c</code> and <code>t</code> in the UNet, we will use 2 additional FCBlocks to convert the label <code>c</code> into 2 tensors, <code>fc1_c</code> and <code>fc2_c</code>, with the same number of hidden dimensions as <code>fc1_t</code> and <code>fc2_t</code> respectively. Then, instead of multiplying the intermediate blocks by the time tensor alone, we will do:
<pre><code>unflat_cond_class = unflat * fc1_c + fc1_t
firstUpBlock_cond = firstUpBlock * fc2_c + fc2_t</code></pre>

The last step is to zero out the class one-hot vectors at the dropout rate, which we can implement efficiently using a mask of the same length as the batch size: multiplying it with the batch of one-hot vectors zeroes out the <code>i</code>-th vector in the batch whenever <code>mask[i] = 0</code>.
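A sketch of this masking trick (NumPy; variable names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
p_uncond, N, num_classes = 0.1, 8, 10

labels = rng.integers(0, num_classes, size=N)
c = np.eye(num_classes)[labels]                # (N, 10) batch of one-hot rows

# Keep mask: entry i is 0 with probability p_uncond, else 1.
mask = (rng.random(N) >= p_uncond).astype(c.dtype)
c_dropped = c * mask[:, None]                  # zeroed rows become unconditional
```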

<h4>Class Conditioning Hyperparameters</h4>
Because class conditioning converges quickly, we will use the same number of training epochs as time conditioning, which is 10. A guidance scale of &gamma; = 5 will be used in the CFG part. The same hyperparameters as the Time-Conditioned UNet will be used for the remaining settings.

<h4>Class-Conditioned Forward and Sampling Operations</h4>
The forward pass will be very similar to the Time-Conditioned UNet's, except that to compute the loss we will also input the training image's label into the model, along with a mask of 1s and 0s whose entries are 0 with probability <code>p<sub>uncond</sub></code>. The training loss curve is as follows:
<div align="center">
<figure>
<img src="images/unet/26_training_curve.png" alt="26_training_curve.png" />
</figure>
</div>

For the sampling operation, we will compute the unconditional velocity estimate <code>u<sub>uncond</sub></code> by passing a mask of all 0s to the model, as well as the conditional estimate <code>u<sub>cond</sub></code> of a given digit using a mask of all 1s. Once we have <code>u<sub>cond</sub></code> and <code>u<sub>uncond</sub></code>, we let our final estimate be <code>u<sub>cfg</sub></code> = <code>u<sub>uncond</sub></code> + &gamma;(<code>u<sub>cond</sub></code> - <code>u<sub>uncond</sub></code>) before updating at each iteration with <code>x</code> &larr; <code>x</code> + (1 / <code>T</code>)<code>u<sub>cfg</sub></code>. Below is a visualization of generating the digits 0-9 4 times each, for epochs 1, 5, and 10:
<div align="center">
<figure>
<img src="images/unet/26_visualization.png" alt="26_visualization.png" />
</figure>
</div>
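The sampling loop described above can be sketched as follows (NumPy; <code>cfg_sample</code> and the model's <code>(x, t, c, mask)</code> signature are illustrative assumptions, not the project's actual API):

```python
import numpy as np

def cfg_sample(u_theta, c, shape, T=300, gamma=5.0, rng=None):
    """Euler integration with classifier-free guidance.

    u_theta(x, t, c, mask): a mask of all 0s yields the unconditional
    velocity estimate, a mask of all 1s the class-conditional one.
    """
    if rng is None:
        rng = np.random.default_rng()
    x = rng.standard_normal(shape)                 # x0: pure noise
    n = shape[0]
    for i in range(T):
        t = i / T
        u_uncond = u_theta(x, t, c, np.zeros(n))   # unconditional estimate
        u_cond = u_theta(x, t, c, np.ones(n))      # conditional estimate
        u_cfg = u_uncond + gamma * (u_cond - u_uncond)
        x = x + (1.0 / T) * u_cfg                  # Euler step
    return x
```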

Even compared to the Time-Conditioned UNet, the improvements are clear and sizable.
</section>

<!-- ========================================================= -->
