|  | 
|  | 1 | +/* | 
|  | 2 | + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | 
|  | 3 | + * or more contributor license agreements. Licensed under the Elastic License | 
|  | 4 | + * 2.0; you may not use this file except in compliance with the Elastic License | 
|  | 5 | + * 2.0. | 
|  | 6 | + */ | 
|  | 7 | + | 
|  | 8 | +package org.elasticsearch.xpack.inference.action.filter; | 
|  | 9 | + | 
|  | 10 | +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; | 
|  | 11 | + | 
|  | 12 | +import org.elasticsearch.ElasticsearchSecurityException; | 
|  | 13 | +import org.elasticsearch.action.bulk.BulkItemResponse; | 
|  | 14 | +import org.elasticsearch.action.bulk.BulkRequestBuilder; | 
|  | 15 | +import org.elasticsearch.action.bulk.BulkResponse; | 
|  | 16 | +import org.elasticsearch.action.index.IndexRequestBuilder; | 
|  | 17 | +import org.elasticsearch.cluster.metadata.IndexMetadata; | 
|  | 18 | +import org.elasticsearch.common.settings.Settings; | 
|  | 19 | +import org.elasticsearch.core.Strings; | 
|  | 20 | +import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; | 
|  | 21 | +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; | 
|  | 22 | +import org.elasticsearch.inference.SimilarityMeasure; | 
|  | 23 | +import org.elasticsearch.license.LicenseSettings; | 
|  | 24 | +import org.elasticsearch.plugins.Plugin; | 
|  | 25 | +import org.elasticsearch.test.ESIntegTestCase; | 
|  | 26 | +import org.elasticsearch.xpack.core.XPackField; | 
|  | 27 | +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; | 
|  | 28 | +import org.elasticsearch.xpack.inference.Utils; | 
|  | 29 | +import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; | 
|  | 30 | +import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; | 
|  | 31 | +import org.junit.Before; | 
|  | 32 | + | 
|  | 33 | +import java.util.Arrays; | 
|  | 34 | +import java.util.Collection; | 
|  | 35 | +import java.util.HashMap; | 
|  | 36 | +import java.util.List; | 
|  | 37 | +import java.util.Locale; | 
|  | 38 | +import java.util.Map; | 
|  | 39 | + | 
|  | 40 | +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput; | 
|  | 41 | +import static org.hamcrest.Matchers.containsString; | 
|  | 42 | +import static org.hamcrest.Matchers.instanceOf; | 
|  | 43 | + | 
|  | 44 | +public class ShardBulkInferenceActionFilterBasicLicenseIT extends ESIntegTestCase { | 
|  | 45 | +    public static final String INDEX_NAME = "test-index"; | 
|  | 46 | + | 
|  | 47 | +    private final boolean useLegacyFormat; | 
|  | 48 | + | 
|  | 49 | +    public ShardBulkInferenceActionFilterBasicLicenseIT(boolean useLegacyFormat) { | 
|  | 50 | +        this.useLegacyFormat = useLegacyFormat; | 
|  | 51 | +    } | 
|  | 52 | + | 
|  | 53 | +    @ParametersFactory | 
|  | 54 | +    public static Iterable<Object[]> parameters() { | 
|  | 55 | +        return List.of(new Object[] { true }, new Object[] { false }); | 
|  | 56 | +    } | 
|  | 57 | + | 
|  | 58 | +    @Before | 
|  | 59 | +    public void setup() throws Exception { | 
|  | 60 | +        Utils.storeSparseModel(client()); | 
|  | 61 | +        Utils.storeDenseModel( | 
|  | 62 | +            client(), | 
|  | 63 | +            randomIntBetween(1, 100), | 
|  | 64 | +            // dot product means that we need normalized vectors; it's not worth doing that in this test | 
|  | 65 | +            randomValueOtherThan(SimilarityMeasure.DOT_PRODUCT, () -> randomFrom(SimilarityMeasure.values())), | 
|  | 66 | +            // TODO: Allow element type BIT once TestDenseInferenceServiceExtension supports it | 
|  | 67 | +            randomValueOtherThan(DenseVectorFieldMapper.ElementType.BIT, () -> randomFrom(DenseVectorFieldMapper.ElementType.values())) | 
|  | 68 | +        ); | 
|  | 69 | +    } | 
|  | 70 | + | 
|  | 71 | +    @Override | 
|  | 72 | +    protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { | 
|  | 73 | +        return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "basic").build(); | 
|  | 74 | +    } | 
|  | 75 | + | 
|  | 76 | +    @Override | 
|  | 77 | +    protected Collection<Class<? extends Plugin>> nodePlugins() { | 
|  | 78 | +        return Arrays.asList(LocalStateInferencePlugin.class); | 
|  | 79 | +    } | 
|  | 80 | + | 
|  | 81 | +    @Override | 
|  | 82 | +    public Settings indexSettings() { | 
|  | 83 | +        var builder = Settings.builder() | 
|  | 84 | +            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 10)) | 
|  | 85 | +            .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat); | 
|  | 86 | +        return builder.build(); | 
|  | 87 | +    } | 
|  | 88 | + | 
|  | 89 | +    public void testLicenseInvalidForInference() { | 
|  | 90 | +        prepareCreate(INDEX_NAME).setMapping( | 
|  | 91 | +            String.format( | 
|  | 92 | +                Locale.ROOT, | 
|  | 93 | +                """ | 
|  | 94 | +                    { | 
|  | 95 | +                        "properties": { | 
|  | 96 | +                            "sparse_field": { | 
|  | 97 | +                                "type": "semantic_text", | 
|  | 98 | +                                "inference_id": "%s" | 
|  | 99 | +                            }, | 
|  | 100 | +                            "dense_field": { | 
|  | 101 | +                                "type": "semantic_text", | 
|  | 102 | +                                "inference_id": "%s" | 
|  | 103 | +                            } | 
|  | 104 | +                        } | 
|  | 105 | +                    } | 
|  | 106 | +                    """, | 
|  | 107 | +                TestSparseInferenceServiceExtension.TestInferenceService.NAME, | 
|  | 108 | +                TestDenseInferenceServiceExtension.TestInferenceService.NAME | 
|  | 109 | +            ) | 
|  | 110 | +        ).get(); | 
|  | 111 | + | 
|  | 112 | +        BulkRequestBuilder bulkRequest = client().prepareBulk(); | 
|  | 113 | +        int totalBulkReqs = randomIntBetween(2, 100); | 
|  | 114 | +        for (int i = 0; i < totalBulkReqs; i++) { | 
|  | 115 | +            Map<String, Object> source = new HashMap<>(); | 
|  | 116 | +            source.put("sparse_field", randomSemanticTextInput()); | 
|  | 117 | +            source.put("dense_field", randomSemanticTextInput()); | 
|  | 118 | + | 
|  | 119 | +            bulkRequest.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(Long.toString(i)).setSource(source)); | 
|  | 120 | +        } | 
|  | 121 | + | 
|  | 122 | +        BulkResponse bulkResponse = bulkRequest.get(); | 
|  | 123 | +        for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) { | 
|  | 124 | +            assertTrue(bulkItemResponse.isFailed()); | 
|  | 125 | +            assertThat(bulkItemResponse.getFailure().getCause(), instanceOf(ElasticsearchSecurityException.class)); | 
|  | 126 | +            assertThat( | 
|  | 127 | +                bulkItemResponse.getFailure().getCause().getMessage(), | 
|  | 128 | +                containsString(Strings.format("current license is non-compliant for [%s]", XPackField.INFERENCE)) | 
|  | 129 | +            ); | 
|  | 130 | +        } | 
|  | 131 | +    } | 
|  | 132 | + | 
|  | 133 | +    public void testNullSourceSucceeds() { | 
|  | 134 | +        prepareCreate(INDEX_NAME).setMapping( | 
|  | 135 | +            String.format( | 
|  | 136 | +                Locale.ROOT, | 
|  | 137 | +                """ | 
|  | 138 | +                    { | 
|  | 139 | +                        "properties": { | 
|  | 140 | +                            "sparse_field": { | 
|  | 141 | +                                "type": "semantic_text", | 
|  | 142 | +                                "inference_id": "%s" | 
|  | 143 | +                            }, | 
|  | 144 | +                            "dense_field": { | 
|  | 145 | +                                "type": "semantic_text", | 
|  | 146 | +                                "inference_id": "%s" | 
|  | 147 | +                            } | 
|  | 148 | +                        } | 
|  | 149 | +                    } | 
|  | 150 | +                    """, | 
|  | 151 | +                TestSparseInferenceServiceExtension.TestInferenceService.NAME, | 
|  | 152 | +                TestDenseInferenceServiceExtension.TestInferenceService.NAME | 
|  | 153 | +            ) | 
|  | 154 | +        ).get(); | 
|  | 155 | + | 
|  | 156 | +        BulkRequestBuilder bulkRequest = client().prepareBulk(); | 
|  | 157 | +        int totalBulkReqs = randomIntBetween(2, 100); | 
|  | 158 | +        Map<String, Object> source = new HashMap<>(); | 
|  | 159 | +        source.put("sparse_field", null); | 
|  | 160 | +        source.put("dense_field", null); | 
|  | 161 | +        for (int i = 0; i < totalBulkReqs; i++) { | 
|  | 162 | +            bulkRequest.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(Long.toString(i)).setSource(source)); | 
|  | 163 | +        } | 
|  | 164 | + | 
|  | 165 | +        BulkResponse bulkResponse = bulkRequest.get(); | 
|  | 166 | +        assertFalse(bulkResponse.hasFailures()); | 
|  | 167 | +    } | 
|  | 168 | +} | 
0 commit comments