DARTS Advanced
==============

Differentiable architecture search

This is an adaptation of Hanxiao Liu et al.'s DARTS algorithm, extending
the work to handle convolutional neural networks for NLP problems and more.
Details of the original authors' approach can be found in their 2019 ICLR paper_.

DARTS works by composing various neural net primitives, defined as PyTorch *nn.Modules*,
to create a larger directed acyclic graph (DAG) that is to be your model. This
composition is differentiable, as we take the softmax over the choice of primitive types
at each layer of the network. To make this clearer, let's first define a few abstractions
in the algorithm:

1. **Primitive**: this is the fundamental block of computation, defined as an *nn.Module*.
   At each layer of your network, one of these primitives will be chosen by taking the
   softmax of all possible primitives at that layer. Examples could be a convolution block,
   a linear layer, a skip connection, or anything that you can come up with (subject to a few
   constraints).

2. **Cell**: this is an abstraction that holds each of the primitive types for a given level of your
   network. This is where we perform the softmax over the possible primitive types (a small code
   sketch of this mixing follows the list).

3. **Node**: this is the level of abstraction that would normally be considered a layer in
   your network. It can contain one or more *Cells*.

4. **Architecture**: the abstraction that contains all nodes in the graph. This computes a
   Hessian product with respect to the *alpha* parameters as defined in the paper.

5. **Genotype**: genotypes are instances of a particular configuration of the graph. As the
   optimization runs, and each cell computes the softmax over its primitive types, the final
   configuration of all nodes with their resulting primitives is a genotype.

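To make the softmax-over-primitives idea concrete, below is a minimal, self-contained sketch of a
mixed operation. It is illustrative only and does not use the benchmark's actual classes; the names
*MixedLayer* and *alpha*, and the particular candidate primitives, are invented for the example.

.. code-block:: python

   import torch
   import torch.nn as nn
   import torch.nn.functional as F

   class MixedLayer(nn.Module):
       """Softmax-weighted sum over candidate primitives (illustrative only)."""

       def __init__(self, primitives):
           super().__init__()
           self.primitives = nn.ModuleList(primitives)

       def forward(self, x, alpha):
           # One architecture weight per primitive; the softmax keeps the
           # choice of primitive differentiable.
           weights = F.softmax(alpha, dim=-1)
           return sum(w * op(x) for w, op in zip(weights, self.primitives))

   # Candidate primitives must share input/output sizes so they can be mixed.
   layer = MixedLayer([nn.Linear(16, 16), nn.Identity(), nn.ReLU()])
   alpha = torch.zeros(3, requires_grad=True)  # learned alongside the network weights
   out = layer(torch.randn(4, 16), alpha)
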
In the DARTS algorithm, we define a number of primitives that we would like to compose together
to form our neural network. The original paper started with 8 primitive types. These types
were originally designed for a vision task, and largely consist of convolution type operations.
We have since adapted these types for the *P3B5* benchmark, creating 1D convolution types for
our NLP tasks. If you would like to see how these primitives are defined, along with their
necessary constructors used by DARTS, you can find them in
`darts.modules.operations.conv.py`_.

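For orientation, a custom 1D primitive might look something like the block below. This is only a
sketch, not one of the benchmark's own operations; see `darts.modules.operations.conv.py`_ for the
real definitions and the constructors DARTS expects.

.. code-block:: python

   import torch.nn as nn

   class ConvBlock1d(nn.Module):
       """Hypothetical 1D convolution primitive for sequence (NLP) inputs."""

       def __init__(self, channels, kernel_size=3):
           super().__init__()
           # "Same" padding keeps the sequence length unchanged, so this block
           # stays interchangeable with the other primitives in a cell.
           self.op = nn.Sequential(
               nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2),
               nn.BatchNorm1d(channels),
               nn.ReLU(),
           )

       def forward(self, x):
           # x has shape (batch, channels, sequence_length)
           return self.op(x)

   # e.g. ConvBlock1d(channels=32) maps (batch, 32, length) -> (batch, 32, length)
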
These primitives are then contained within a cell, and one or more cells are contained within a
node in the graph. DARTS then works by composing these nodes together and taking the softmax over
their primitives in each cell. Finally, the *Architecture* abstraction contains all nodes, and is
responsible for differentiating the composition of the nodes with respect to two *alpha* parameters
as defined in the paper. The end result is that we have a differentiable model that composes its
components as the model is training.

As the optimization runs, the model will print the resulting loss with respect to a given *Genotype*.
The final model will be the *Genotype* corresponding to the lowest loss.

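As a rough illustration of how a *Genotype* relates to the *alpha* parameters, the discrete
configuration can be read off by keeping the highest-weighted primitive in each cell. The names
below are invented for the example; the benchmark has its own genotype handling.

.. code-block:: python

   import torch

   # Illustrative only: pick the highest-weighted primitive in each cell.
   primitive_names = ["conv_3", "conv_5", "linear", "skip_connect"]
   alphas = torch.randn(4, len(primitive_names))  # one row of weights per cell

   genotype = [primitive_names[i] for i in alphas.softmax(dim=-1).argmax(dim=-1).tolist()]
   print(genotype)  # e.g. ['conv_5', 'skip_connect', 'conv_3', 'linear']
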
Advanced Example
----------------

In this example we will take a look at how to define our own primitives to be handled by DARTS. If
you have not read the `Uno example`_, I would recommend taking a look at that first. There we showed
how we can use the primitives built in to DARTS. As a reference, you can also look to see how those

...

of the primitives must have the same number of input and output features; this lets us get the number
of features from any of your primitives. Since DARTS cannot know ahead of time what your primitives will be,
we must specify how many features will go into our final fully connected layer of the network.
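
For example (hypothetical sizes, not the benchmark's API), if every primitive keeps the feature
dimension fixed, the only size that has to be supplied by hand is the input to the final classifier:

.. code-block:: python

   import torch.nn as nn

   # Hypothetical values: every primitive maps num_features -> num_features,
   # so only the final fully connected layer needs its input size spelled out.
   num_features, num_classes = 64, 10
   classifier = nn.Linear(num_features, num_classes)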

Run the Example
---------------

First, make sure that you can get the example data by installing `torchvision`:

.. code-block::

   pip install torchvision

Then run the example with:

.. code-block::

|