|
| 1 | +import pytest |
| 2 | +from pydantic import BaseModel |
| 3 | + |
| 4 | +from common.asyncapi import ( |
| 5 | + get_schema, |
| 6 | + init_asyncapi_info, |
| 7 | + register_channel, |
| 8 | + register_channel_operation, |
| 9 | + register_server, |
| 10 | +) |
| 11 | + |
| 12 | + |
| 13 | +# Test fixtures |
| 14 | +@pytest.fixture |
| 15 | +def reset_asyncapi_state(): |
| 16 | + """Reset all global state between tests""" |
| 17 | + from common.asyncapi import _channels, _components_schemas, _operations, _servers |
| 18 | + |
| 19 | + _servers.clear() |
| 20 | + _channels.clear() |
| 21 | + _operations.clear() |
| 22 | + _components_schemas.clear() |
| 23 | + yield |
| 24 | + _servers.clear() |
| 25 | + _channels.clear() |
| 26 | + _operations.clear() |
| 27 | + _components_schemas.clear() |
| 28 | + |
| 29 | + |
| 30 | +# Test message models |
| 31 | +class TestMessage(BaseModel): |
| 32 | + content: str |
| 33 | + timestamp: int |
| 34 | + |
| 35 | + |
| 36 | +class AnotherTestMessage(BaseModel): |
| 37 | + status: bool |
| 38 | + code: int |
| 39 | + |
| 40 | + |
| 41 | +# Test cases |
| 42 | +def test_init_asyncapi_info(): |
| 43 | + """Test initialization of AsyncAPI info""" |
| 44 | + title = "Test API" |
| 45 | + version = "2.0.0" |
| 46 | + |
| 47 | + init_asyncapi_info(title=title, version=version) |
| 48 | + schema = get_schema() |
| 49 | + |
| 50 | + assert schema.info.title == title |
| 51 | + assert schema.info.version == version |
| 52 | + |
| 53 | + |
| 54 | +def test_register_server(reset_asyncapi_state): |
| 55 | + """Test server registration""" |
| 56 | + server_id = "test-server" |
| 57 | + host = "localhost" |
| 58 | + protocol = "ws" |
| 59 | + pathname = "/ws" |
| 60 | + |
| 61 | + register_server(id=server_id, host=host, protocol=protocol, pathname=pathname) |
| 62 | + |
| 63 | + schema = get_schema() |
| 64 | + assert server_id in schema.servers |
| 65 | + assert schema.servers[server_id].host == host |
| 66 | + assert schema.servers[server_id].protocol == protocol |
| 67 | + assert schema.servers[server_id].pathname == pathname |
| 68 | + |
| 69 | + |
| 70 | +def test_register_channel(reset_asyncapi_state): |
| 71 | + """Test channel registration""" |
| 72 | + channel_id = "test-channel" |
| 73 | + address = "test/topic" |
| 74 | + description = "Test channel" |
| 75 | + title = "Test Channel" |
| 76 | + |
| 77 | + register_channel(address=address, id=channel_id, description=description, title=title) |
| 78 | + |
| 79 | + schema = get_schema() |
| 80 | + assert channel_id in schema.channels |
| 81 | + assert schema.channels[channel_id].address == address |
| 82 | + assert schema.channels[channel_id].description == description |
| 83 | + assert schema.channels[channel_id].title == title |
| 84 | + |
| 85 | + |
| 86 | +def test_register_channel_with_server(reset_asyncapi_state): |
| 87 | + """Test channel registration with server reference""" |
| 88 | + server_id = "test-server" |
| 89 | + channel_id = "test-channel" |
| 90 | + |
| 91 | + register_server(id=server_id, host="localhost", protocol="ws") |
| 92 | + register_channel(address="test/topic", id=channel_id, server_id=server_id) |
| 93 | + |
| 94 | + schema = get_schema() |
| 95 | + assert len(schema.channels[channel_id].servers) == 1 |
| 96 | + assert schema.channels[channel_id].servers[0].ref == f"#/servers/{server_id}" |
| 97 | + |
| 98 | + |
| 99 | +def test_register_channel_operation(reset_asyncapi_state): |
| 100 | + """Test channel operation registration""" |
| 101 | + channel_id = "test-channel" |
| 102 | + operation_type = "receive" |
| 103 | + |
| 104 | + register_channel(address="test/topic", id=channel_id) |
| 105 | + register_channel_operation( |
| 106 | + channel_id=channel_id, operation_type=operation_type, messages=[TestMessage], operation_name="test-operation" |
| 107 | + ) |
| 108 | + |
| 109 | + schema = get_schema() |
| 110 | + assert "test-operation" in schema.operations |
| 111 | + assert schema.operations["test-operation"].action == operation_type |
| 112 | + assert schema.operations["test-operation"].channel.ref == f"#/channels/{channel_id}" |
| 113 | + assert TestMessage.__name__ in schema.components.schemas |
| 114 | + |
| 115 | + |
| 116 | +def test_register_channel_operation_invalid_channel(reset_asyncapi_state): |
| 117 | + """Test channel operation registration with invalid channel""" |
| 118 | + with pytest.raises(ValueError, match="Channel non-existent does not exist"): |
| 119 | + register_channel_operation(channel_id="non-existent", operation_type="receive", messages=[TestMessage]) |
| 120 | + |
| 121 | + |
| 122 | +def test_multiple_messages_registration(reset_asyncapi_state): |
| 123 | + """Test registration of multiple messages for an operation""" |
| 124 | + channel_id = "test-channel" |
| 125 | + |
| 126 | + register_channel(address="test/topic", id=channel_id) |
| 127 | + register_channel_operation(channel_id=channel_id, operation_type="send", messages=[TestMessage, AnotherTestMessage]) |
| 128 | + |
| 129 | + schema = get_schema() |
| 130 | + assert TestMessage.__name__ in schema.components.schemas |
| 131 | + assert AnotherTestMessage.__name__ in schema.components.schemas |
| 132 | + assert TestMessage.__name__ in schema.channels[channel_id].messages |
| 133 | + assert AnotherTestMessage.__name__ in schema.channels[channel_id].messages |
0 commit comments