|
1 | 1 | """ |
2 | 2 | This script is used when converting models from PyTorch to TF. |
3 | 3 | """ |
| 4 | +import logging |
| 5 | + |
4 | 6 | import numpy as np |
5 | 7 | import tensorflow as tf |
6 | 8 | import timm |
|
10 | 12 | import tfimm # noqa: F401 |
11 | 13 | from tfimm.utils.timm import load_pytorch_weights_in_tf2_model # noqa: F401 |
12 | 14 |
|
13 | | -model_name = "resnet18" |
| 15 | +logging.basicConfig(level=logging.INFO) |
| 16 | + |
| 17 | +model_name = "efficientnet_b0" |
| 18 | +pt_model_name = "tf_efficientnet_b0" |
14 | 19 |
|
15 | 20 | # We need to test models in both training and inference mode (BN) |
16 | 21 | training = False |
17 | 22 | nb_calls = 3 |
18 | 23 |
|
19 | 24 | # Load PyTorch model |
20 | | -pt_model = timm.create_model(model_name, pretrained=True) |
| 25 | +pt_model = timm.create_model( |
| 26 | + pt_model_name, pretrained=True, drop_rate=0.0, drop_path_rate=0.0 |
| 27 | +) |
21 | 28 | # If a model is not part of the `timm` library, we can load the state dict directly |
22 | 29 | # state_dict = load_state_dict_from_url( |
23 | 30 | # url="https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m48.pth.tar" # noqa: E501 |
|
35 | 42 | if not training: # Set PyTorch model to inference mode |
36 | 43 | pt_model.eval() |
37 | 44 |
|
| 45 | +# Create test input |
| 46 | +img = np.random.rand(5, 224, 224, 3).astype("float32") |
| 47 | + |
| 48 | +# Run inference for PyTorch model |
| 49 | +pt_img = torch.Tensor(img.transpose([0, 3, 1, 2])) |
| 50 | +if training: |
| 51 | + for _ in range(nb_calls): |
| 52 | + _ = pt_model.forward(pt_img) |
| 53 | +pt_res = pt_model.forward(pt_img) |
| 54 | +pt_res = pt_res.detach().numpy() |
| 55 | +# When we look at output of intermediate layers, we have to transpose PyTorch data |
| 56 | +# format (NCHW) to TF data format (NHWC). We don't have to do this if we only look |
| 57 | +# at the final logits |
| 58 | +# pt_res = pt_res.transpose([0, 2, 3, 1]) |
| 59 | +print(pt_res.shape) |
| 60 | + |
38 | 61 | # Load TF model |
39 | | -tf_model = tfimm.create_model(model_name, pretrained="timm") |
| 62 | +tf_model = tfimm.create_model( |
| 63 | + model_name, pretrained=True, drop_rate=0.0, drop_path_rate=0.0 |
| 64 | +) |
40 | 65 | # If we want to load the weights from a pytorch model outside the model factory: |
41 | 66 | # load_pytorch_weights_in_tf2_model(tf_model, pt_model.state_dict()) |
42 | 67 | # For debug purposes we may want to print variable names |
43 | 68 | # for w in tf_model.weights: |
44 | 69 | # print(w.name) |
45 | 70 |
|
46 | | -# Create test input |
47 | | -img = np.random.rand(5, 224, 224, 3).astype("float32") |
48 | | - |
49 | 71 | # Run inference for TF model |
50 | 72 | tf_img = tf.constant(img) |
51 | 73 | if training: # If training we do multiple forward passes to test BN param updates |
|
59 | 81 | tf_res = tf_res.numpy() |
60 | 82 | print(tf_res.shape) |
61 | 83 |
|
62 | | -# Run inference for PyTorch model |
63 | | -pt_img = torch.Tensor(img.transpose([0, 3, 1, 2])) |
64 | | -if training: |
65 | | - for _ in range(nb_calls): |
66 | | - _ = pt_model.forward(pt_img) |
67 | | -pt_res = pt_model.forward(pt_img) |
68 | | -pt_res = pt_res.detach().numpy() |
69 | | -# When we look at output of intermediate layers, we have to transpose PyTorch data |
70 | | -# format (NCHW) to TF data format (NHWC). We don't have to do this, if we only look |
71 | | -# at the final logits |
72 | | -# pt_res = pt_res.transpose([0, 2, 3, 1]) |
73 | | -print(pt_res.shape) |
74 | | - |
75 | 84 | # Compare outputs between PyTorch and Tensorflow. We should expect the relative error |
76 | 85 | # to be <1e-5. It won't be much lower, because TF and PyTorch implement BN slightly |
77 | 86 | # differently. The two formulas are mathematically, but not numerically, equivalent. |
|
0 commit comments