Skip to content

Commit a063aee

Browse files
committed
[ILUVATAR_GPU] Support for iluvatar_gpu
1 parent 98f82a5 commit a063aee

File tree

5 files changed

+29
-2
lines changed

5 files changed

+29
-2
lines changed

ppocr/data/imaug/operators.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import numpy as np
2525
import math
2626
from PIL import Image
27+
from paddle import get_device
2728

2829

2930
class DecodeImage(object):
@@ -231,6 +232,8 @@ def __call__(self, data):
231232
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
232233
data["image"] = img
233234
data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w])
235+
if "iluvatar_gpu" in get_device():
236+
data["shape"] = data["shape"].astype(np.float32)
234237
return data
235238

236239
def image_padding(self, im, value=0):

ppocr/data/imaug/random_crop_data.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import numpy as np
2525
import cv2
2626
import random
27+
from paddle import get_device
2728

2829

2930
def is_poly_in_rect(poly, x, y, w, h):
@@ -179,6 +180,8 @@ def __call__(self, data):
179180
texts_crop.append(text)
180181
data["image"] = img
181182
data["polys"] = np.array(text_polys_crop)
183+
if "iluvatar_gpu" in get_device():
184+
data["polys"] = np.array(text_polys_crop).astype(np.float32)
182185
data["ignore_tags"] = ignore_tags_crop
183186
data["texts"] = texts_crop
184187
return data

ppocr/data/imaug/rec_img_aug.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
SVTRDeterioration,
2929
ParseQDeterioration,
3030
)
31+
from paddle import get_device
3132
from paddle.vision.transforms import Compose
3233

3334

@@ -305,6 +306,8 @@ def __call__(self, data):
305306
norm_img, valid_ratio = resize_norm_img(img, self.image_shape, self.padding)
306307
data["image"] = norm_img
307308
data["valid_ratio"] = valid_ratio
309+
if "iluvatar_gpu" in get_device():
310+
data["valid_ratio"] = np.float32(valid_ratio)
308311
return data
309312

310313

@@ -338,6 +341,8 @@ def __call__(self, data):
338341

339342
data["image"] = norm_img
340343
data["valid_ratio"] = valid_ratio
344+
if "iluvatar_gpu" in get_device():
345+
data["valid_ratio"] = np.float32(valid_ratio)
341346
return data
342347

343348

@@ -366,6 +371,8 @@ def __call__(self, data):
366371
)
367372
data["image"] = norm_img
368373
data["valid_ratio"] = valid_ratio
374+
if "iluvatar_gpu" in get_device():
375+
data["valid_ratio"] = np.float32(valid_ratio)
369376
return data
370377

371378

@@ -407,6 +414,8 @@ def __call__(self, data):
407414
data["resized_shape"] = resize_shape
408415
data["pad_shape"] = pad_shape
409416
data["valid_ratio"] = valid_ratio
417+
if "iluvatar_gpu" in get_device():
418+
data["valid_ratio"] = np.float32(valid_ratio)
410419
return data
411420

412421

@@ -539,6 +548,8 @@ def __call__(self, data):
539548
norm_img, valid_ratio = resize_norm_img_abinet(img, self.image_shape)
540549
data["image"] = norm_img
541550
data["valid_ratio"] = valid_ratio
551+
if "iluvatar_gpu" in get_device():
552+
data["valid_ratio"] = np.float32(valid_ratio)
542553
return data
543554

544555

@@ -553,6 +564,8 @@ def __call__(self, data):
553564
norm_img, valid_ratio = resize_norm_img(img, self.image_shape, self.padding)
554565
data["image"] = norm_img
555566
data["valid_ratio"] = valid_ratio
567+
if "iluvatar_gpu" in get_device():
568+
data["valid_ratio"] = np.float32(valid_ratio)
556569
return data
557570

558571

@@ -574,6 +587,8 @@ def __call__(self, data):
574587
data["resized_shape"] = resize_shape
575588
data["pad_shape"] = pad_shape
576589
data["valid_ratio"] = valid_ratio
590+
if "iluvatar_gpu" in get_device():
591+
data["valid_ratio"] = np.float32(valid_ratio)
577592
data["word_positons"] = word_positons
578593
return data
579594

ppocr/data/simple_dataset.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import traceback
2121
from paddle.io import Dataset
2222
from .imaug import transform, create_operators
23+
from paddle import get_device
2324

2425

2526
class SimpleDataSet(Dataset):
@@ -203,6 +204,8 @@ def resize_norm_img(self, data, imgW, imgH, padding=True):
203204
valid_ratio = min(1.0, float(resized_w / imgW))
204205
data["image"] = padding_im
205206
data["valid_ratio"] = valid_ratio
207+
if "iluvatar_gpu" in get_device():
208+
data["valid_ratio"] = np.float32(valid_ratio)
206209
return data
207210

208211
def __getitem__(self, properties):

tools/program.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def merge_config(config, opts):
115115
return config
116116

117117

118-
def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False, use_gcu=False):
118+
def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False, use_gcu=False, use_iluvatar_gpu=False):
119119
"""
120120
Log error and exit when set use_gpu=true in paddlepaddle
121121
cpu version.
@@ -833,6 +833,7 @@ def preprocess(is_train=False):
833833
use_npu = config["Global"].get("use_npu", False)
834834
use_mlu = config["Global"].get("use_mlu", False)
835835
use_gcu = config["Global"].get("use_gcu", False)
836+
use_iluvatar_gpu = config["Global"].get("use_iluvatar_gpu", False)
836837

837838
alg = config["Architecture"]["algorithm"]
838839
assert alg in [
@@ -896,9 +897,11 @@ def preprocess(is_train=False):
896897
device = "mlu:{0}".format(os.getenv("FLAGS_selected_mlus", 0))
897898
elif use_gcu: # Use Enflame GCU(General Compute Unit)
898899
device = "gcu:{0}".format(os.getenv("FLAGS_selected_gcus", 0))
900+
elif use_iluvatar_gpu:
901+
device = "iluvatar_gpu:{0}".format(dist.ParallelEnv().dev_id)
899902
else:
900903
device = "gpu:{}".format(dist.ParallelEnv().dev_id) if use_gpu else "cpu"
901-
check_device(use_gpu, use_xpu, use_npu, use_mlu, use_gcu)
904+
check_device(use_gpu, use_xpu, use_npu, use_mlu, use_gcu, use_iluvatar_gpu)
902905

903906
device = paddle.set_device(device)
904907

0 commit comments

Comments
 (0)