Skip to content

Commit 17ba5fe

Browse files
authored
Fix type handling of polymorphic arguments (#70)
* Refactor type handling to use isinstance() instead of type() This makes it support the weird edge of str being subclassed in Home Assistant * Throw an exception when an invalid argument type is provided instead of continuing with the class code path
1 parent 44dff5e commit 17ba5fe

File tree

1 file changed

+25
-17
lines changed

1 file changed

+25
-17
lines changed

openrgb/orgb.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -294,23 +294,25 @@ def set_colors(self, colors: list[utils.RGBColor], fast: bool = False):
294294
elif active_mode.color_mode == utils.ModeColors.PER_LED:
295295
self._set_device_colors(colors, fast)
296296

297-
def set_mode(self, mode: Union[int, str, utils.ModeData], save: bool = False):
297+
def set_mode(self, mode: Union[str, int, utils.ModeData], save: bool = False):
298298
'''
299299
Sets the device's mode
300300
301-
:param mode: the id, name, or the ModeData object itself to set as the mode
301+
:param mode: the name, id, or the ModeData object itself to set as the mode
302302
'''
303-
if type(mode) == utils.ModeData:
304-
pass
305-
elif type(mode) == int:
306-
mode = self.modes[mode]
307-
elif type(mode) == str:
303+
if isinstance(mode, str):
308304
try:
309305
mode = next(
310306
(m for m in self.modes if m.name.lower() == mode.lower()))
311307
except StopIteration as e:
312308
raise ValueError(
313309
f"Mode `{mode}` not found for device `{self.name}`") from e
310+
elif isinstance(mode, int):
311+
mode = self.modes[mode]
312+
elif isinstance(mode, utils.ModeData):
313+
pass
314+
else:
315+
raise TypeError()
314316
data = mode.pack(self.comms._protocol_version) # type: ignore
315317
self.comms.send_header(
316318
self.id,
@@ -454,7 +456,7 @@ def load_profile(self, name: Union[str, int, utils.Profile], local: bool = False
454456
:param directory: what directory the profile is in. Defaults to OpenRGB's config directory for supported OS's (Windows or Linux), or falls back to using the current working directory.
455457
'''
456458
if local:
457-
assert type(name) is str
459+
assert isinstance(name, str)
458460
if directory == '':
459461
if platform.system() == "Linux":
460462
directory = environ['HOME'].rstrip(
@@ -482,17 +484,19 @@ def load_profile(self, name: Union[str, int, utils.Profile], local: bool = False
482484
if new_controller.active_mode != device.active_mode:
483485
device.set_mode(new_controller.active_mode)
484486
else:
485-
if type(name) is str:
487+
if isinstance(name, str):
486488
try:
487489
name = next(
488490
p for p in self.profiles if p.name.lower() == name.lower())
489491
except StopIteration as e:
490492
raise ValueError(
491493
f"`{name}` is not an existing profile") from e
492-
elif type(name) is int:
494+
elif isinstance(name, int):
493495
name = self.profiles[name]
494-
elif type(name) is utils.Profile:
496+
elif isinstance(name, utils.Profile):
495497
pass
498+
else:
499+
raise TypeError()
496500
raw_name = name.pack() # type: ignore
497501
self.comms.send_header(
498502
0, utils.PacketType.REQUEST_LOAD_PROFILE, len(raw_name))
@@ -521,16 +525,18 @@ def save_profile(self, name: Union[str, int, utils.Profile], local: bool = False
521525
f.write(utils.LocalProfile(
522526
[dev.data for dev in self.devices]).pack())
523527
else:
524-
if type(name) is str:
528+
if isinstance(name, str):
525529
try:
526530
name = next(
527531
p for p in self.profiles if p.name.lower() == name.lower())
528532
except StopIteration:
529533
name = utils.Profile(name) # type: ignore
530-
elif type(name) is int:
534+
elif isinstance(name, int):
531535
name = self.profiles[name]
532-
elif type(name) is utils.Profile:
536+
elif isinstance(name, utils.Profile):
533537
pass
538+
else:
539+
raise TypeError()
534540
raw_name = name.pack() # type: ignore
535541
self.comms.send_header(
536542
0, utils.PacketType.REQUEST_SAVE_PROFILE, len(raw_name))
@@ -543,16 +549,18 @@ def delete_profile(self, name: Union[str, int, utils.Profile]):
543549
544550
:param name: Can be a profile's name, index, or even the Profile itself
545551
'''
546-
if type(name) is str:
552+
if isinstance(name, str):
547553
try:
548554
name = next(
549555
p for p in self.profiles if p.name.lower() == name.lower())
550556
except StopIteration as e:
551557
raise ValueError(f"`{name}` is not an existing profile") from e
552-
elif type(name) is int:
558+
elif isinstance(name, int):
553559
name = self.profiles[name]
554-
elif type(name) is utils.Profile:
560+
elif isinstance(name, utils.Profile):
555561
pass
562+
else:
563+
raise TypeError()
556564
raw_name = name.pack() # type: ignore
557565
self.comms.send_header(
558566
0, utils.PacketType.REQUEST_DELETE_PROFILE, len(raw_name))

0 commit comments

Comments
 (0)