| 
11 | 11 | 
 
  | 
12 | 12 |     from tests.unit.sdk.conftest import BothClients  | 
13 | 13 | 
 
  | 
 | 14 | +client_types = ["standard", "sync"]  | 
14 | 15 | 
 
  | 
 | 16 | + | 
 | 17 | +@pytest.mark.parametrize("client_type", client_types)  | 
 | 18 | +async def test_batch_execution(clients: BothClients, client_type: str):  | 
 | 19 | +    r: list[int] = []  | 
 | 20 | +    tasks_number = 10  | 
 | 21 | + | 
 | 22 | +    if client_type == "standard":  | 
 | 23 | + | 
 | 24 | +        async def test_func() -> int:  | 
 | 25 | +            return 1  | 
 | 26 | + | 
 | 27 | +        batch = await clients.standard.create_batch()  | 
 | 28 | +        for _ in range(tasks_number):  | 
 | 29 | +            batch.add(task=test_func)  | 
 | 30 | + | 
 | 31 | +        assert batch.num_tasks == tasks_number  | 
 | 32 | +        async for _, result in batch.execute():  | 
 | 33 | +            r.append(result)  | 
 | 34 | +    else:  | 
 | 35 | + | 
 | 36 | +        def test_func() -> int:  | 
 | 37 | +            return 1  | 
 | 38 | + | 
 | 39 | +        batch = clients.sync.create_batch()  | 
 | 40 | +        for _ in range(tasks_number):  | 
 | 41 | +            batch.add(task=test_func)  | 
 | 42 | + | 
 | 43 | +        assert batch.num_tasks == tasks_number  | 
 | 44 | +        for _, result in batch.execute():  | 
 | 45 | +            r.append(result)  | 
 | 46 | + | 
 | 47 | +    assert r == [1] * tasks_number  | 
 | 48 | + | 
 | 49 | + | 
 | 50 | +@pytest.mark.parametrize("client_type", client_types)  | 
15 | 51 | async def test_batch_return_exception(  | 
16 |  | -    httpx_mock: HTTPXMock, mock_query_mutation_location_create_failed, mock_schema_query_01, clients: BothClients  | 
 | 52 | +    httpx_mock: HTTPXMock,  | 
 | 53 | +    mock_query_mutation_location_create_failed,  | 
 | 54 | +    mock_schema_query_01,  | 
 | 55 | +    clients: BothClients,  | 
 | 56 | +    client_type: str,  | 
17 | 57 | ):  # pylint: disable=unused-argument  | 
18 |  | -    batch = await clients.standard.create_batch(return_exceptions=True)  | 
19 |  | -    locations = ["JFK1", "JFK1"]  | 
20 |  | -    results = []  | 
21 |  | -    for location_name in locations:  | 
22 |  | -        data = {"name": {"value": location_name, "is_protected": True}}  | 
23 |  | -        obj = await clients.standard.create(kind="BuiltinLocation", data=data)  | 
24 |  | -        batch.add(task=obj.save, node=obj)  | 
25 |  | -        results.append(obj)  | 
26 |  | - | 
27 |  | -    result_iter = batch.execute()  | 
28 |  | -    # Assert first node success  | 
29 |  | -    node, result = await result_iter.__anext__()  | 
30 |  | -    assert node == results[0]  | 
31 |  | -    assert not isinstance(result, Exception)  | 
32 |  | - | 
33 |  | -    # Assert second node failure  | 
34 |  | -    node, result = await result_iter.__anext__()  | 
35 |  | -    assert node == results[1]  | 
36 |  | -    assert isinstance(result, GraphQLError)  | 
37 |  | -    assert "An error occurred while executing the GraphQL Query" in str(result)  | 
 | 58 | +    if client_type == "standard":  | 
 | 59 | +        batch = await clients.standard.create_batch(return_exceptions=True)  | 
 | 60 | +        locations = ["JFK1", "JFK1"]  | 
 | 61 | +        results = []  | 
 | 62 | +        for location_name in locations:  | 
 | 63 | +            data = {"name": {"value": location_name, "is_protected": True}}  | 
 | 64 | +            obj = await clients.standard.create(kind="BuiltinLocation", data=data)  | 
 | 65 | +            batch.add(task=obj.save, node=obj)  | 
 | 66 | +            results.append(obj)  | 
 | 67 | + | 
 | 68 | +        result_iter = batch.execute()  | 
 | 69 | +        # Assert first node success  | 
 | 70 | +        node, result = await result_iter.__anext__()  | 
 | 71 | +        assert node == results[0]  | 
 | 72 | +        assert not isinstance(result, Exception)  | 
38 | 73 | 
 
  | 
 | 74 | +        # Assert second node failure  | 
 | 75 | +        node, result = await result_iter.__anext__()  | 
 | 76 | +        assert node == results[1]  | 
 | 77 | +        assert isinstance(result, GraphQLError)  | 
 | 78 | +        assert "An error occurred while executing the GraphQL Query" in str(result)  | 
 | 79 | +    else:  | 
 | 80 | +        batch = clients.sync.create_batch(return_exceptions=True)  | 
 | 81 | +        locations = ["JFK1", "JFK1"]  | 
 | 82 | +        results = []  | 
 | 83 | +        for location_name in locations:  | 
 | 84 | +            data = {"name": {"value": location_name, "is_protected": True}}  | 
 | 85 | +            obj = clients.sync.create(kind="BuiltinLocation", data=data)  | 
 | 86 | +            batch.add(task=obj.save, node=obj)  | 
 | 87 | +            results.append(obj)  | 
39 | 88 | 
 
  | 
 | 89 | +        results = [r for _, r in batch.execute()]  | 
 | 90 | +        # Must have one exception and one graphqlerror  | 
 | 91 | +        assert len(results) == 2  | 
 | 92 | +        assert any(isinstance(r, Exception) for r in results)  | 
 | 93 | +        assert any(isinstance(r, GraphQLError) for r in results)  | 
 | 94 | + | 
 | 95 | + | 
 | 96 | +@pytest.mark.parametrize("client_type", client_types)  | 
40 | 97 | async def test_batch_exception(  | 
41 |  | -    httpx_mock: HTTPXMock, mock_query_mutation_location_create_failed, mock_schema_query_01, clients: BothClients  | 
 | 98 | +    httpx_mock: HTTPXMock,  | 
 | 99 | +    mock_query_mutation_location_create_failed,  | 
 | 100 | +    mock_schema_query_01,  | 
 | 101 | +    clients: BothClients,  | 
 | 102 | +    client_type: str,  | 
42 | 103 | ):  # pylint: disable=unused-argument  | 
43 |  | -    batch = await clients.standard.create_batch(return_exceptions=False)  | 
44 |  | -    locations = ["JFK1", "JFK1"]  | 
45 |  | -    for location_name in locations:  | 
46 |  | -        data = {"name": {"value": location_name, "is_protected": True}}  | 
47 |  | -        obj = await clients.standard.create(kind="BuiltinLocation", data=data)  | 
48 |  | -        batch.add(task=obj.save, node=obj)  | 
49 |  | - | 
50 |  | -    with pytest.raises(GraphQLError) as exc:  | 
51 |  | -        async for _, _ in batch.execute():  | 
52 |  | -            pass  | 
53 |  | -    assert "An error occurred while executing the GraphQL Query" in str(exc.value)  | 
54 |  | - | 
55 |  | - | 
56 |  | -async def test_batch_not_implemented_sync(clients: BothClients):  | 
57 |  | -    with pytest.raises(NotImplementedError):  | 
58 |  | -        clients.sync.create_batch()  | 
 | 104 | +    if client_type == "standard":  | 
 | 105 | +        batch = await clients.standard.create_batch(return_exceptions=False)  | 
 | 106 | +        locations = ["JFK1", "JFK1"]  | 
 | 107 | +        for location_name in locations:  | 
 | 108 | +            data = {"name": {"value": location_name, "is_protected": True}}  | 
 | 109 | +            obj = await clients.standard.create(kind="BuiltinLocation", data=data)  | 
 | 110 | +            batch.add(task=obj.save, node=obj)  | 
 | 111 | + | 
 | 112 | +        with pytest.raises(GraphQLError) as exc:  | 
 | 113 | +            async for _, _ in batch.execute():  | 
 | 114 | +                pass  | 
 | 115 | +        assert "An error occurred while executing the GraphQL Query" in str(exc.value)  | 
 | 116 | +    else:  | 
 | 117 | +        batch = clients.sync.create_batch(return_exceptions=False)  | 
 | 118 | +        locations = ["JFK1", "JFK1"]  | 
 | 119 | +        for location_name in locations:  | 
 | 120 | +            data = {"name": {"value": location_name, "is_protected": True}}  | 
 | 121 | +            obj = clients.sync.create(kind="BuiltinLocation", data=data)  | 
 | 122 | +            batch.add(task=obj.save, node=obj)  | 
 | 123 | + | 
 | 124 | +        with pytest.raises(GraphQLError) as exc:  | 
 | 125 | +            for _, _ in batch.execute():  | 
 | 126 | +                pass  | 
 | 127 | +        assert "An error occurred while executing the GraphQL Query" in str(exc.value)  | 
0 commit comments