Skip to content

Commit d9aac49

Browse files
authored
add pydantic support to rtc.VideoFrame & rtc.AudioFrame (#348)
1 parent ea9205f commit d9aac49

File tree

4 files changed

+113
-3
lines changed

4 files changed

+113
-3
lines changed

.github/workflows/check-types.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@ jobs:
2929
run: python -m pip install --upgrade mypy
3030

3131
- name: Install packages
32-
run: python -m pip install pytest ./livekit-api ./livekit-protocol ./livekit-rtc
32+
run: python -m pip install pytest ./livekit-api ./livekit-protocol ./livekit-rtc pydantic
3333

3434
- name: Check Types
3535
run: python -m mypy --install-type --non-interactive -p 'livekit-protocol' -p 'livekit-api' -p 'livekit-rtc'

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,5 +23,5 @@ jobs:
2323
- name: Run tests
2424
run: |
2525
python3 ./livekit-rtc/rust-sdks/download_ffi.py --output livekit-rtc/livekit/rtc/resources
26-
pip3 install pytest ./livekit-protocol ./livekit-api ./livekit-rtc
26+
pip3 install pytest ./livekit-protocol ./livekit-api ./livekit-rtc pydantic
2727
pytest . --ignore=livekit-rtc/rust-sdks

livekit-rtc/livekit/rtc/audio_frame.py

Lines changed: 60 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
from ._proto import audio_frame_pb2 as proto_audio
1818
from ._proto import ffi_pb2 as proto_ffi
1919
from ._utils import get_address
20-
from typing import Union
20+
from typing import Any, Union
2121

2222

2323
class AudioFrame:
@@ -55,6 +55,10 @@ def __init__(
5555
"data length must be >= num_channels * samples_per_channel * sizeof(int16)"
5656
)
5757

58+
if len(data) % ctypes.sizeof(ctypes.c_int16) != 0:
59+
# can happen if data is bigger than needed
60+
raise ValueError("data length must be a multiple of sizeof(int16)")
61+
5862
self._data = bytearray(data)
5963
self._sample_rate = sample_rate
6064
self._num_channels = num_channels
@@ -197,3 +201,58 @@ def __repr__(self) -> str:
197201
f"samples_per_channel={self.samples_per_channel}, "
198202
f"duration={self.duration:.3f})"
199203
)
204+
205+
@classmethod
def __get_pydantic_core_schema__(cls, *_: Any):
    """
    Build the pydantic-core schema that lets AudioFrame be used as a field type.

    Validation accepts an existing AudioFrame unchanged, or a dict carrying the
    keys "data" (base64 text), "sample_rate", "num_channels" and
    "samples_per_channel", from which a new AudioFrame is constructed.
    Serialization emits a dict with those same four keys, with the frame data
    base64-encoded and decoded to a UTF-8 string.
    """
    from pydantic_core import core_schema
    import base64

    def validate_audio_frame(value: Any) -> "AudioFrame":
        # Already the right type: hand it back untouched.
        if isinstance(value, AudioFrame):
            return value

        # Presumably the tuple produced by the model_fields_schema step of the
        # JSON chain below, whose first item is the validated field dict.
        payload = value[0] if isinstance(value, tuple) else value

        if not isinstance(payload, dict):
            # NOTE(review): pydantic documents ValueError/AssertionError as the
            # errors a validator should raise; a TypeError likely propagates
            # without being wrapped in a ValidationError - confirm intended.
            raise TypeError("Invalid type for AudioFrame")

        return AudioFrame(
            data=base64.b64decode(payload["data"]),
            sample_rate=payload["sample_rate"],
            num_channels=payload["num_channels"],
            samples_per_channel=payload["samples_per_channel"],
        )

    def serialize_audio_frame(instance: Any) -> Any:
        # Mirror of the dict layout accepted by validate_audio_frame.
        encoded = base64.b64encode(instance.data).decode("utf-8")
        return {
            "data": encoded,
            "sample_rate": instance.sample_rate,
            "num_channels": instance.num_channels,
            "samples_per_channel": instance.samples_per_channel,
        }

    # JSON input: "data" must be a string, every other field an integer.
    fields = {"data": core_schema.model_field(core_schema.str_schema())}
    for name in ("sample_rate", "num_channels", "samples_per_channel"):
        fields[name] = core_schema.model_field(core_schema.int_schema())

    # Check the field types first, then build the frame from the result.
    json_chain = core_schema.chain_schema(
        [
            core_schema.model_fields_schema(fields),
            core_schema.no_info_plain_validator_function(validate_audio_frame),
        ]
    )
    python_validator = core_schema.no_info_plain_validator_function(
        validate_audio_frame
    )
    serializer = core_schema.plain_serializer_function_ser_schema(
        serialize_audio_frame
    )

    return core_schema.json_or_python_schema(
        json_schema=json_chain,
        python_schema=python_validator,
        serialization=serializer,
    )

livekit-rtc/livekit/rtc/video_frame.py

Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,8 @@
2020
from ._ffi_client import FfiClient, FfiHandle
2121
from ._utils import get_address
2222

23+
from typing import Any
24+
2325

2426
class VideoFrame:
2527
"""
@@ -203,6 +205,55 @@ def convert(
203205
def __repr__(self) -> str:
204206
return f"rtc.VideoFrame(width={self.width}, height={self.height}, type={self.type})"
205207

208+
@classmethod
def __get_pydantic_core_schema__(cls, *_: Any):
    """
    Build the pydantic-core schema that lets VideoFrame be used as a field type.

    Validation accepts an existing VideoFrame unchanged, or a dict carrying the
    keys "width", "height", "type" and "data" (base64 text), from which a new
    VideoFrame is constructed. Serialization emits a dict with those same four
    keys, with the frame data base64-encoded and decoded to a UTF-8 string.
    """
    from pydantic_core import core_schema
    import base64

    def validate_video_frame(value: Any) -> "VideoFrame":
        # Already the right type: hand it back untouched.
        if isinstance(value, VideoFrame):
            return value

        # Presumably the tuple produced by the model_fields_schema step of the
        # JSON chain below, whose first item is the validated field dict.
        payload = value[0] if isinstance(value, tuple) else value

        if not isinstance(payload, dict):
            # NOTE(review): pydantic documents ValueError/AssertionError as the
            # errors a validator should raise; a TypeError likely propagates
            # without being wrapped in a ValidationError - confirm intended.
            raise TypeError("Invalid type for VideoFrame")

        return VideoFrame(
            width=payload["width"],
            height=payload["height"],
            type=proto_video.VideoBufferType.ValueType(payload["type"]),
            data=base64.b64decode(payload["data"]),
        )

    def serialize_video_frame(instance: Any) -> Any:
        # Mirror of the dict layout accepted by validate_video_frame.
        return {
            "width": instance.width,
            "height": instance.height,
            "type": instance.type,
            "data": base64.b64encode(instance.data).decode("utf-8"),
        }

    # JSON input: the three numeric fields are integers, "data" is a string.
    fields = {
        name: core_schema.model_field(core_schema.int_schema())
        for name in ("width", "height", "type")
    }
    fields["data"] = core_schema.model_field(core_schema.str_schema())

    # Check the field types first, then build the frame from the result.
    json_chain = core_schema.chain_schema(
        [
            core_schema.model_fields_schema(fields),
            core_schema.no_info_plain_validator_function(validate_video_frame),
        ]
    )
    python_validator = core_schema.no_info_plain_validator_function(
        validate_video_frame
    )
    serializer = core_schema.plain_serializer_function_ser_schema(
        serialize_video_frame
    )

    return core_schema.json_or_python_schema(
        json_schema=json_chain,
        python_schema=python_validator,
        serialization=serializer,
    )
256+
206257

207258
def _component_info(
208259
data_ptr: int, stride: int, size: int

0 commit comments

Comments
 (0)