Skip to content

Commit 0076aaf

Browse files
authored
Move color_extractor services to separate module (home-assistant#158341)
1 parent c50f4d6 commit 0076aaf

File tree

3 files changed

+168
-146
lines changed

3 files changed

+168
-146
lines changed

homeassistant/components/color_extractor/__init__.py

Lines changed: 5 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -1,157 +1,19 @@
11
"""Module for color_extractor (RGB extraction from images) component."""
22

3-
import asyncio
4-
import io
5-
import logging
6-
7-
import aiohttp
8-
from colorthief import ColorThief
9-
from PIL import UnidentifiedImageError
10-
import voluptuous as vol
11-
12-
from homeassistant.components.light import (
13-
ATTR_RGB_COLOR,
14-
DOMAIN as LIGHT_DOMAIN,
15-
LIGHT_TURN_ON_SCHEMA,
16-
)
173
from homeassistant.config_entries import ConfigEntry
18-
from homeassistant.const import SERVICE_TURN_ON as LIGHT_SERVICE_TURN_ON
19-
from homeassistant.core import HomeAssistant, ServiceCall
20-
from homeassistant.helpers import aiohttp_client, config_validation as cv
4+
from homeassistant.core import HomeAssistant
5+
from homeassistant.helpers import config_validation as cv
216
from homeassistant.helpers.typing import ConfigType
227

23-
from .const import ATTR_PATH, ATTR_URL, DOMAIN, SERVICE_TURN_ON
24-
25-
_LOGGER = logging.getLogger(__name__)
8+
from .const import DOMAIN
9+
from .services import async_setup_services
2610

2711
CONFIG_SCHEMA = cv.removed(DOMAIN, raise_if_present=False)
2812

29-
# Extend the existing light.turn_on service schema
30-
SERVICE_SCHEMA = vol.All(
31-
cv.has_at_least_one_key(ATTR_URL, ATTR_PATH),
32-
cv.make_entity_service_schema(
33-
{
34-
**LIGHT_TURN_ON_SCHEMA,
35-
vol.Exclusive(ATTR_PATH, "color_extractor"): cv.isfile,
36-
vol.Exclusive(ATTR_URL, "color_extractor"): cv.url,
37-
}
38-
),
39-
)
40-
41-
42-
def _get_file(file_path):
43-
"""Get a PIL acceptable input file reference.
44-
45-
Allows us to mock patch during testing to make BytesIO stream.
46-
"""
47-
return file_path
48-
49-
50-
def _get_color(file_handler) -> tuple:
51-
"""Given an image file, extract the predominant color from it."""
52-
color_thief = ColorThief(file_handler)
53-
54-
# get_color returns a SINGLE RGB value for the given image
55-
color = color_thief.get_color(quality=1)
56-
57-
_LOGGER.debug("Extracted RGB color %s from image", color)
58-
59-
return color
60-
6113

6214
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
6315
"""Set up the Color extractor component."""
64-
65-
async def async_handle_service(service_call: ServiceCall) -> None:
66-
"""Decide which color_extractor method to call based on service."""
67-
service_data = dict(service_call.data)
68-
69-
try:
70-
if ATTR_URL in service_data:
71-
image_type = "URL"
72-
image_reference = service_data.pop(ATTR_URL)
73-
color = await async_extract_color_from_url(image_reference)
74-
75-
elif ATTR_PATH in service_data:
76-
image_type = "file path"
77-
image_reference = service_data.pop(ATTR_PATH)
78-
color = await hass.async_add_executor_job(
79-
extract_color_from_path, image_reference
80-
)
81-
82-
except UnidentifiedImageError as ex:
83-
_LOGGER.error(
84-
"Bad image from %s '%s' provided, are you sure it's an image? %s",
85-
image_type,
86-
image_reference,
87-
ex,
88-
)
89-
return
90-
91-
if color:
92-
service_data[ATTR_RGB_COLOR] = color
93-
94-
await hass.services.async_call(
95-
LIGHT_DOMAIN, LIGHT_SERVICE_TURN_ON, service_data, blocking=True
96-
)
97-
98-
hass.services.async_register(
99-
DOMAIN,
100-
SERVICE_TURN_ON,
101-
async_handle_service,
102-
schema=SERVICE_SCHEMA,
103-
)
104-
105-
async def async_extract_color_from_url(url):
106-
"""Handle call for URL based image."""
107-
if not hass.config.is_allowed_external_url(url):
108-
_LOGGER.error(
109-
(
110-
"External URL '%s' is not allowed, please add to"
111-
" 'allowlist_external_urls'"
112-
),
113-
url,
114-
)
115-
return None
116-
117-
_LOGGER.debug("Getting predominant RGB from image URL '%s'", url)
118-
119-
# Download the image into a buffer for ColorThief to check against
120-
try:
121-
session = aiohttp_client.async_get_clientsession(hass)
122-
123-
async with asyncio.timeout(10):
124-
response = await session.get(url)
125-
126-
except (TimeoutError, aiohttp.ClientError) as err:
127-
_LOGGER.error("Failed to get ColorThief image due to HTTPError: %s", err)
128-
return None
129-
130-
content = await response.content.read()
131-
132-
with io.BytesIO(content) as _file:
133-
_file.name = "color_extractor.jpg"
134-
_file.seek(0)
135-
136-
return _get_color(_file)
137-
138-
def extract_color_from_path(file_path):
139-
"""Handle call for local file based image."""
140-
if not hass.config.is_allowed_path(file_path):
141-
_LOGGER.error(
142-
(
143-
"File path '%s' is not allowed, please add to"
144-
" 'allowlist_external_dirs'"
145-
),
146-
file_path,
147-
)
148-
return None
149-
150-
_LOGGER.debug("Getting predominant RGB from file path '%s'", file_path)
151-
152-
_file = _get_file(file_path)
153-
return _get_color(_file)
154-
16+
async_setup_services(hass)
15517
return True
15618

15719

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
"""Module for color_extractor (RGB extraction from images) component."""
2+
3+
import asyncio
4+
import io
5+
import logging
6+
7+
import aiohttp
8+
from colorthief import ColorThief
9+
from PIL import UnidentifiedImageError
10+
import voluptuous as vol
11+
12+
from homeassistant.components.light import (
13+
ATTR_RGB_COLOR,
14+
DOMAIN as LIGHT_DOMAIN,
15+
LIGHT_TURN_ON_SCHEMA,
16+
)
17+
from homeassistant.const import SERVICE_TURN_ON as LIGHT_SERVICE_TURN_ON
18+
from homeassistant.core import HomeAssistant, ServiceCall, callback
19+
from homeassistant.helpers import aiohttp_client, config_validation as cv
20+
21+
from .const import ATTR_PATH, ATTR_URL, DOMAIN, SERVICE_TURN_ON
22+
23+
_LOGGER = logging.getLogger(__name__)
24+
25+
# Extend the existing light.turn_on service schema
26+
SERVICE_SCHEMA = vol.All(
27+
cv.has_at_least_one_key(ATTR_URL, ATTR_PATH),
28+
cv.make_entity_service_schema(
29+
{
30+
**LIGHT_TURN_ON_SCHEMA,
31+
vol.Exclusive(ATTR_PATH, "color_extractor"): cv.isfile,
32+
vol.Exclusive(ATTR_URL, "color_extractor"): cv.url,
33+
}
34+
),
35+
)
36+
37+
38+
def _get_file(file_path: str) -> str:
39+
"""Get a PIL acceptable input file reference.
40+
41+
Allows us to mock patch during testing to make BytesIO stream.
42+
"""
43+
return file_path
44+
45+
46+
def _get_color(file_handler: io.BytesIO | str) -> tuple[int, int, int]:
47+
"""Given an image file, extract the predominant color from it."""
48+
color_thief = ColorThief(file_handler)
49+
50+
# get_color returns a SINGLE RGB value for the given image
51+
color = color_thief.get_color(quality=1)
52+
53+
_LOGGER.debug("Extracted RGB color %s from image", color)
54+
55+
return color
56+
57+
58+
async def _async_extract_color_from_url(
59+
hass: HomeAssistant, url: str
60+
) -> tuple[int, int, int] | None:
61+
"""Handle call for URL based image."""
62+
if not hass.config.is_allowed_external_url(url):
63+
_LOGGER.error(
64+
(
65+
"External URL '%s' is not allowed, please add to"
66+
" 'allowlist_external_urls'"
67+
),
68+
url,
69+
)
70+
return None
71+
72+
_LOGGER.debug("Getting predominant RGB from image URL '%s'", url)
73+
74+
# Download the image into a buffer for ColorThief to check against
75+
try:
76+
session = aiohttp_client.async_get_clientsession(hass)
77+
78+
async with asyncio.timeout(10):
79+
response = await session.get(url)
80+
81+
except (TimeoutError, aiohttp.ClientError) as err:
82+
_LOGGER.error("Failed to get ColorThief image due to HTTPError: %s", err)
83+
return None
84+
85+
content = await response.content.read()
86+
87+
with io.BytesIO(content) as _file:
88+
_file.name = "color_extractor.jpg"
89+
_file.seek(0)
90+
91+
return _get_color(_file)
92+
93+
94+
def _extract_color_from_path(
95+
hass: HomeAssistant, file_path: str
96+
) -> tuple[int, int, int] | None:
97+
"""Handle call for local file based image."""
98+
if not hass.config.is_allowed_path(file_path):
99+
_LOGGER.error(
100+
"File path '%s' is not allowed, please add to 'allowlist_external_dirs'",
101+
file_path,
102+
)
103+
return None
104+
105+
_LOGGER.debug("Getting predominant RGB from file path '%s'", file_path)
106+
107+
_file = _get_file(file_path)
108+
return _get_color(_file)
109+
110+
111+
async def async_handle_service(service_call: ServiceCall) -> None:
112+
"""Decide which color_extractor method to call based on service."""
113+
service_data = dict(service_call.data)
114+
115+
try:
116+
if ATTR_URL in service_data:
117+
image_type = "URL"
118+
image_reference = service_data.pop(ATTR_URL)
119+
color = await _async_extract_color_from_url(
120+
service_call.hass, image_reference
121+
)
122+
123+
elif ATTR_PATH in service_data:
124+
image_type = "file path"
125+
image_reference = service_data.pop(ATTR_PATH)
126+
color = await service_call.hass.async_add_executor_job(
127+
_extract_color_from_path, service_call.hass, image_reference
128+
)
129+
130+
except UnidentifiedImageError as ex:
131+
_LOGGER.error(
132+
"Bad image from %s '%s' provided, are you sure it's an image? %s",
133+
image_type,
134+
image_reference,
135+
ex,
136+
)
137+
return
138+
139+
if color:
140+
service_data[ATTR_RGB_COLOR] = color
141+
142+
await service_call.hass.services.async_call(
143+
LIGHT_DOMAIN, LIGHT_SERVICE_TURN_ON, service_data, blocking=True
144+
)
145+
146+
147+
@callback
148+
def async_setup_services(hass: HomeAssistant) -> None:
149+
"""Register the services."""
150+
151+
hass.services.async_register(
152+
DOMAIN,
153+
SERVICE_TURN_ON,
154+
async_handle_service,
155+
schema=SERVICE_SCHEMA,
156+
)

tests/components/color_extractor/test_services.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pytest
1010
from voluptuous.error import MultipleInvalid
1111

12-
from homeassistant.components.color_extractor import (
12+
from homeassistant.components.color_extractor.services import (
1313
ATTR_PATH,
1414
ATTR_URL,
1515
DOMAIN,
@@ -270,7 +270,9 @@ async def test_file(hass: HomeAssistant, setup_integration) -> None:
270270
assert state.state == STATE_OFF
271271

272272
# Mock the file handler read with our 1x1 base64 encoded fixture image
273-
with patch("homeassistant.components.color_extractor._get_file", _get_file_mock):
273+
with patch(
274+
"homeassistant.components.color_extractor.services._get_file", _get_file_mock
275+
):
274276
await hass.services.async_call(DOMAIN, SERVICE_TURN_ON, service_data)
275277
await hass.async_block_till_done()
276278

@@ -305,7 +307,9 @@ async def test_file_denied_dir(hass: HomeAssistant, setup_integration) -> None:
305307
assert state.state == STATE_OFF
306308

307309
# Mock the file handler read with our 1x1 base64 encoded fixture image
308-
with patch("homeassistant.components.color_extractor._get_file", _get_file_mock):
310+
with patch(
311+
"homeassistant.components.color_extractor.services._get_file", _get_file_mock
312+
):
309313
await hass.services.async_call(DOMAIN, SERVICE_TURN_ON, service_data)
310314
await hass.async_block_till_done()
311315

0 commit comments

Comments
 (0)