|
1 | 1 | from copy import deepcopy |
2 | 2 | from dataclasses import dataclass |
3 | | -from typing import List, Literal, Optional, Union |
| 3 | +from typing import List, Literal, Optional |
4 | 4 |
|
5 | 5 | from pydantic import BaseModel, Field, field_validator, validate_call |
6 | 6 | from vonage_http_client.http_client import HttpClient |
7 | | -from vonage_sms.errors import SmsError |
| 7 | + |
| 8 | +from .errors import SmsError, PartialFailureError |
8 | 9 |
|
9 | 10 |
|
10 | 11 | class SmsMessage(BaseModel): |
11 | 12 | to: str |
12 | 13 | from_: str = Field(..., alias="from") |
13 | 14 | text: str |
14 | | - type: Optional[str] = None |
15 | 15 | sig: Optional[str] = Field(None, min_length=16, max_length=60) |
16 | | - status_report_req: Optional[int] = Field( |
17 | | - None, |
18 | | - alias="status-report-req", |
19 | | - description="Set to 1 to receive a Delivery Receipt", |
20 | | - ) |
21 | | - client_ref: Optional[str] = Field( |
22 | | - None, alias="client-ref", description="Your own reference. Up to 40 characters." |
23 | | - ) |
24 | | - network_code: Optional[str] = Field( |
25 | | - None, |
26 | | - alias="network-code", |
27 | | - description="A 4-5 digit number that represents the mobile carrier network code", |
28 | | - ) |
| 16 | + client_ref: Optional[str] = Field(None, alias="client-ref", max_length=100) |
| 17 | + type: Optional[Literal['text', 'binary', 'unicode']] = None |
| 18 | + ttl: Optional[int] = Field(None, ge=20000, le=604800000) |
| 19 | + status_report_req: Optional[bool] = Field(None, alias='status-report-req') |
| 20 | + callback: Optional[str] = Field(None, max_length=100) |
| 21 | + message_class: Optional[int] = Field(None, alias='message-class', ge=0, le=3) |
| 22 | + body: Optional[str] = None |
| 23 | + udh: Optional[str] = None |
| 24 | + protocol_id: Optional[int] = Field(None, alias='protocol-id', ge=0, le=255) |
| 25 | + account_ref: Optional[str] = Field(None, alias='account-ref') |
| 26 | + entity_id: Optional[str] = Field(None, alias='entity-id') |
| 27 | + content_id: Optional[str] = Field(None, alias='content-id') |
| 28 | + |
| 29 | + @field_validator('body', 'udh') |
| 30 | + @classmethod |
| 31 | + def validate_body(cls, value, values): |
| 32 | + if 'type' not in values or not values['type'] == 'binary': |
| 33 | + raise ValueError( |
| 34 | + 'This parameter can only be set when the "type" parameter is set to "binary".' |
| 35 | + ) |
| 36 | + if values['type'] == 'binary' and not value: |
| 37 | + raise ValueError('This parameter is required for binary messages.') |
| 38 | + |
| 39 | + |
| 40 | +@dataclass |
| 41 | +class MessageResponse: |
| 42 | + to: str |
| 43 | + message_id: str |
| 44 | + status: str |
| 45 | + remaining_balance: str |
| 46 | + message_price: str |
| 47 | + network: str |
| 48 | + client_ref: Optional[str] = None |
| 49 | + account_ref: Optional[str] = None |
29 | 50 |
|
30 | 51 |
|
31 | 52 | @dataclass |
32 | 53 | class SmsResponse: |
33 | | - id: str |
| 54 | + message_count: str |
| 55 | + messages: List[MessageResponse] |
34 | 56 |
|
35 | 57 |
|
36 | 58 | class Sms: |
37 | 59 | """Calls Vonage's SMS API.""" |
38 | 60 |
|
39 | 61 | def __init__(self, http_client: HttpClient) -> None: |
40 | 62 | self._http_client = deepcopy(http_client) |
41 | | - self._auth_type = 'basic' |
| 63 | + if self._http_client._auth._signature_secret: |
| 64 | + self._auth_type = 'signature' |
| 65 | + else: |
| 66 | + self._auth_type = 'basic' |
42 | 67 |
|
43 | 68 | @validate_call |
44 | 69 | def send(self, message: SmsMessage) -> SmsResponse: |
45 | 70 | """Send an SMS message.""" |
46 | 71 | response = self._http_client.post( |
47 | | - self._http_client.api_host, |
48 | | - '/v2/ni', |
49 | | - message.model_dump(), |
| 72 | + self._http_client.rest_host, |
| 73 | + '/sms/json', |
| 74 | + message.model_dump(by_alias=True), |
50 | 75 | self._auth_type, |
51 | 76 | ) |
| 77 | + |
| 78 | + if int(response['message-count']) > 1: |
| 79 | + self.check_for_partial_failure(response) |
| 80 | + else: |
| 81 | + self.check_for_error(response) |
| 82 | + |
| 83 | + messages = [] |
| 84 | + for message in response['messages']: |
| 85 | + messages.append(MessageResponse(**message)) |
| 86 | + |
| 87 | + return SmsResponse(message_count=response['message-count'], messages=messages) |
| 88 | + |
| 89 | + def check_for_partial_failure(self, response_data): |
| 90 | + successful_messages = 0 |
| 91 | + total_messages = int(response_data['message-count']) |
| 92 | + |
| 93 | + for message in response_data['messages']: |
| 94 | + if message['status'] == '0': |
| 95 | + successful_messages += 1 |
| 96 | + if successful_messages < total_messages: |
| 97 | + raise PartialFailureError(response_data) |
| 98 | + |
| 99 | + def check_for_error(self, response_data): |
| 100 | + message = response_data['messages'][0] |
| 101 | + if int(message['status']) != 0: |
| 102 | + raise SmsError( |
| 103 | + f'Sms.send_message method failed with error code {message["status"]}: {message["error-text"]}' |
| 104 | + ) |
0 commit comments