Skip to content

Commit e75d8ae

Browse files
authored
vitstr: remove second input (#3662)
1 parent 336bae9 commit e75d8ae

File tree

1 file changed

+13
-3
lines changed
  • models/public/vitstr-small-patch16-224

1 file changed

+13
-3
lines changed

models/public/vitstr-small-patch16-224/model.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,26 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from torch import load
15+
import torch
1616
from modules.vitstr import vitstr_small_patch16_224
1717

1818

19+
class Model(torch.nn.Module):
20+
def __init__(self, model):
21+
super(Model, self).__init__()
22+
self.model = model
23+
24+
def forward(self, x):
25+
return self.model(x)
26+
27+
28+
1929
def create_model(weights):
2030
model = vitstr_small_patch16_224(num_classes=96)
2131

22-
checkpoint = load(weights, map_location='cpu')
32+
checkpoint = torch.load(weights, map_location='cpu')
2333
ckpt = {k.replace('module.vitstr.', ''): v for k, v in checkpoint.items()}
2434

2535
model.load_state_dict(ckpt)
2636

27-
return model
37+
return Model(model)

0 commit comments

Comments
 (0)