Skip to content

Commit a3b493f

Browse files
committed
make style
1 parent d52e973 commit a3b493f

File tree

3 files changed

+11
-8
lines changed

3 files changed

+11
-8
lines changed

examples/research_projects/anytext/auxiliary_latent_module.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,14 @@ def retrieve_latents(
2929

3030
class AuxiliaryLatentModule(nn.Module):
3131
def __init__(
32-
self, font_path, glyph_channels=1, position_channels=1, model_channels=320, vae=None, device="cpu", use_fp16=False
32+
self,
33+
font_path,
34+
glyph_channels=1,
35+
position_channels=1,
36+
model_channels=320,
37+
vae=None,
38+
device="cpu",
39+
use_fp16=False,
3340
):
3441
super().__init__()
3542
self.font = ImageFont.truetype(font_path, 60)
@@ -78,12 +85,8 @@ def __init__(
7885

7986
self.fuse_block = nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1)
8087

81-
self.glyph_block.load_state_dict(
82-
load_file("glyph_block.safetensors", device=str(self.device))
83-
)
84-
self.position_block.load_state_dict(
85-
load_file("position_block.safetensors", device=str(self.device))
86-
)
88+
self.glyph_block.load_state_dict(load_file("glyph_block.safetensors", device=str(self.device)))
89+
self.position_block.load_state_dict(load_file("position_block.safetensors", device=str(self.device)))
8790
self.fuse_block.load_state_dict(load_file("fuse_block.safetensors", device=str(self.device)))
8891

8992
if use_fp16:

examples/research_projects/anytext/ocr_recog/RecModel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def load_3rd_state_dict(self, _3rd_name, _state):
3535

3636
def forward(self, x):
3737
import torch
38+
3839
x = x.to(torch.float32)
3940
x = self.backbone(x)
4041
x = self.neck(x)

examples/research_projects/anytext/recognizer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import torch.nn.functional as F
1414
from easydict import EasyDict as edict
1515
from ocr_recog.RecModel import RecModel
16-
from safetensors.torch import load_file
1716
from skimage.transform._geometric import _umeyama as get_sym_mat
1817

1918

0 commit comments

Comments
 (0)