Skip to content

Commit da4c856

Browse files
authored
[NOID] Fixes #1596: Change key/secret to optional in apoc.nlp calls for AWS (#4062) (#4117)
1 parent 20cb235 commit da4c856

File tree

4 files changed

+224
-36
lines changed

4 files changed

+224
-36
lines changed

full/src/main/kotlin/apoc/nlp/aws/AWSProcedures.kt

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import apoc.nlp.NLPHelperFunctions
2323
import apoc.nlp.NLPHelperFunctions.getNodeProperty
2424
import apoc.nlp.NLPHelperFunctions.keyPhraseRelationshipType
2525
import apoc.nlp.NLPHelperFunctions.partition
26-
import apoc.nlp.NLPHelperFunctions.verifyKey
2726
import apoc.nlp.NLPHelperFunctions.verifyNodeProperty
2827
import apoc.nlp.NLPHelperFunctions.verifySource
2928
import apoc.result.NodeWithMapResult
@@ -59,8 +58,6 @@ class AWSProcedures {
5958
verifySource(source)
6059
val nodeProperty = getNodeProperty(config)
6160
verifyNodeProperty(source, nodeProperty)
62-
verifyKey(config, "key")
63-
verifyKey(config, "secret")
6461

6562
val client: AWSClient = awsClient(config)
6663

@@ -78,8 +75,6 @@ class AWSProcedures {
7875
verifySource(source)
7976
val nodeProperty = getNodeProperty(config)
8077
verifyNodeProperty(source, nodeProperty)
81-
verifyKey(config, "key")
82-
verifyKey(config, "secret")
8378

8479
val client = awsClient(config)
8580
val relationshipType = NLPHelperFunctions.entityRelationshipType(config)
@@ -103,8 +98,6 @@ class AWSProcedures {
10398
verifySource(source)
10499
val nodeProperty = getNodeProperty(config)
105100
verifyNodeProperty(source, nodeProperty)
106-
verifyKey(config, "key")
107-
verifyKey(config, "secret")
108101

109102
val client: AWSClient = awsClient(config)
110103

@@ -124,8 +117,6 @@ class AWSProcedures {
124117
verifySource(source)
125118
val nodeProperty = getNodeProperty(config)
126119
verifyNodeProperty(source, nodeProperty)
127-
verifyKey(config, "key")
128-
verifyKey(config, "secret")
129120

130121
val client = awsClient(config)
131122
val relationshipType = keyPhraseRelationshipType(config)
@@ -149,8 +140,6 @@ class AWSProcedures {
149140
verifySource(source)
150141
val nodeProperty = getNodeProperty(config)
151142
verifyNodeProperty(source, nodeProperty)
152-
verifyKey(config, "key")
153-
verifyKey(config, "secret")
154143

155144
val client: AWSClient = awsClient(config)
156145

@@ -170,8 +159,6 @@ class AWSProcedures {
170159
verifySource(source)
171160
val nodeProperty = getNodeProperty(config)
172161
verifyNodeProperty(source, nodeProperty)
173-
verifyKey(config, "key")
174-
verifyKey(config, "secret")
175162

176163
val client = awsClient(config)
177164
val storeGraph: Boolean = config.getOrDefault("write", false) as Boolean

full/src/main/kotlin/apoc/nlp/aws/RealAWSClient.kt

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,46 +20,75 @@ package apoc.nlp.aws
2020

2121
import apoc.result.MapResult
2222
import apoc.util.JsonUtil
23-
import com.amazonaws.auth.AWSStaticCredentialsProvider
24-
import com.amazonaws.auth.BasicAWSCredentials
23+
import com.amazonaws.auth.*
2524
import com.amazonaws.services.comprehend.AmazonComprehendClientBuilder
26-
import com.amazonaws.services.comprehend.model.BatchDetectEntitiesRequest
27-
import com.amazonaws.services.comprehend.model.BatchDetectEntitiesResult
28-
import com.amazonaws.services.comprehend.model.BatchDetectKeyPhrasesRequest
29-
import com.amazonaws.services.comprehend.model.BatchDetectKeyPhrasesResult
30-
import com.amazonaws.services.comprehend.model.BatchDetectSentimentRequest
31-
import com.amazonaws.services.comprehend.model.BatchDetectSentimentResult
25+
import com.amazonaws.services.comprehend.model.*
3226
import org.neo4j.graphdb.Node
3327
import org.neo4j.logging.Log
3428

3529
class RealAWSClient(config: Map<String, Any>, private val log: Log) : AWSClient {
36-
private val apiKey = config["key"].toString()
37-
private val apiSecret = config["secret"].toString()
30+
companion object {
31+
val missingCredentialError = """
32+
Error during AWS credentials retrieving.
33+
Make sure the key ID and the Secret Key are defined via `key` and `secret` parameters
34+
or via one of these ways: https://docs.aws.amazon.com/AWSJavaSDK/latest/javadoc/com/amazonaws/auth/DefaultAWSCredentialsProviderChain.html:
35+
"""
36+
}
37+
private val apiKey = config["key"]?.toString()
38+
private val apiSecret = config["secret"]?.toString()
39+
private val apiSessionToken = config["token"].toString()
3840
private val region = config.getOrDefault("region", "us-east-1").toString()
3941
private val language = config.getOrDefault("language", "en").toString()
4042
private val nodeProperty = config.getOrDefault("nodeProperty", "text").toString()
4143

4244
private val awsClient = AmazonComprehendClientBuilder.standard()
43-
.withCredentials(AWSStaticCredentialsProvider(BasicAWSCredentials(apiKey, apiSecret)))
45+
.withCredentials(awsStaticCredentialsProvider())
4446
.withRegion(region)
4547
.build()
4648

47-
override fun entities(data: List<Node>, batchId: Int): BatchDetectEntitiesResult? {
48-
val convertedData = convertInput(data)
49-
val batch = BatchDetectEntitiesRequest().withTextList(convertedData).withLanguageCode(language)
50-
return awsClient.batchDetectEntities(batch)
49+
private fun awsStaticCredentialsProvider(): AWSCredentialsProvider {
50+
return if (!apiKey.isNullOrEmpty() && !apiSecret.isNullOrEmpty()) {
51+
AWSStaticCredentialsProvider(getAwsBasicCredentials())
52+
} else {
53+
DefaultAWSCredentialsProviderChain()
54+
}
55+
}
56+
57+
private fun getAwsBasicCredentials() : AWSCredentials = if (apiSessionToken.isEmpty()) {
58+
BasicAWSCredentials(apiKey, apiSecret)
59+
} else {
60+
BasicSessionCredentials(apiKey, apiSecret, apiSessionToken)
61+
}
62+
63+
64+
override fun entities(data: List<Node>, batchId: Int): BatchDetectEntitiesResult? {
65+
try {
66+
val convertedData = convertInput(data)
67+
val batch = BatchDetectEntitiesRequest().withTextList(convertedData).withLanguageCode(language)
68+
return awsClient.batchDetectEntities(batch)
69+
} catch (e: Exception) {
70+
throw RuntimeException(missingCredentialError + e)
71+
}
5172
}
5273

5374
override fun keyPhrases(data: List<Node>, batchId: Int): BatchDetectKeyPhrasesResult? {
54-
val convertedData = convertInput(data)
55-
val batch = BatchDetectKeyPhrasesRequest().withTextList(convertedData).withLanguageCode(language)
56-
return awsClient.batchDetectKeyPhrases(batch)
75+
try {
76+
val convertedData = convertInput(data)
77+
val batch = BatchDetectKeyPhrasesRequest().withTextList(convertedData).withLanguageCode(language)
78+
return awsClient.batchDetectKeyPhrases(batch)
79+
} catch (e: Exception) {
80+
throw RuntimeException(missingCredentialError + e)
81+
}
5782
}
5883

5984
override fun sentiment(data: List<Node>, batchId: Int): BatchDetectSentimentResult? {
60-
val convertedData = convertInput(data)
61-
val batch = BatchDetectSentimentRequest().withTextList(convertedData).withLanguageCode(language)
62-
return awsClient.batchDetectSentiment(batch)
85+
try {
86+
val convertedData = convertInput(data)
87+
val batch = BatchDetectSentimentRequest().withTextList(convertedData).withLanguageCode(language)
88+
return awsClient.batchDetectSentiment(batch)
89+
} catch (e: Exception) {
90+
throw RuntimeException(missingCredentialError + e)
91+
}
6392
}
6493

6594
fun sentiment(data: List<Node>, config: Map<String, Any?>): List<MapResult> {
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
package apoc.nlp.aws
2+
3+
import apoc.util.TestUtil
4+
import com.amazonaws.SDKGlobalConfiguration.ACCESS_KEY_ENV_VAR
5+
import com.amazonaws.SDKGlobalConfiguration.SECRET_KEY_ENV_VAR
6+
import org.junit.Assert.assertTrue
7+
import org.junit.Assume.assumeNotNull
8+
import org.junit.BeforeClass
9+
import org.junit.ClassRule
10+
import org.junit.Test
11+
import org.neo4j.graphdb.Result
12+
import org.neo4j.test.rule.ImpermanentDbmsRule
13+
14+
15+
/**
16+
* To execute tests, set these environment variables:
17+
* AWS_ACCESS_KEY_ID=<apiKey>;AWS_SECRET_KEY=<secretKey>
18+
*/
19+
class AWSProceduresAPIWithEnvVarsTest {
20+
companion object {
21+
private val apiKey: String? = System.getenv(ACCESS_KEY_ENV_VAR)
22+
private val apiSecret: String? = System.getenv(SECRET_KEY_ENV_VAR)
23+
24+
@ClassRule
25+
@JvmField
26+
val neo4j = ImpermanentDbmsRule()
27+
28+
@BeforeClass
29+
@JvmStatic
30+
fun beforeClass() {
31+
neo4j.executeTransactionally("""
32+
CREATE (:Article {
33+
uri: "https://neo4j.com/blog/pokegraph-gotta-graph-em-all/",
34+
body: "These days I’m rarely more than a few feet away from my Nintendo Switch and I play board games, card games and role playing games with friends at least once or twice a week. I’ve even organised lunch-time Mario Kart 8 tournaments between the Neo4j European offices!"
35+
});""")
36+
37+
neo4j.executeTransactionally("""
38+
CREATE (:Article {
39+
uri: "https://en.wikipedia.org/wiki/Nintendo_Switch",
40+
body: "The Nintendo Switch is a video game console developed by Nintendo, released worldwide in most regions on March 3, 2017. It is a hybrid console that can be used as a home console and portable device. The Nintendo Switch was unveiled on October 20, 2016. Nintendo offers a Joy-Con Wheel, a small steering wheel-like unit that a Joy-Con can slot into, allowing it to be used for racing games such as Mario Kart 8."
41+
});
42+
""")
43+
44+
assumeNotNull(apiKey, apiSecret)
45+
TestUtil.registerProcedure(neo4j, AWSProcedures::class.java)
46+
}
47+
}
48+
49+
@Test
50+
fun `should extract entities in stream mode`() {
51+
neo4j.executeTransactionally("""
52+
MATCH (a:Article {uri: "https://neo4j.com/blog/pokegraph-gotta-graph-em-all/"})
53+
CALL apoc.nlp.aws.entities.stream(a, {
54+
nodeProperty: "body"
55+
})
56+
YIELD value
57+
UNWIND value.entities AS result
58+
RETURN result;
59+
""", mapOf()) {
60+
assertStreamWithScoreResult(it)
61+
}
62+
}
63+
64+
@Test
65+
fun `should extract entities in graph mode`() {
66+
neo4j.executeTransactionally("""
67+
MATCH (a:Article {uri: "https://neo4j.com/blog/pokegraph-gotta-graph-em-all/"})
68+
CALL apoc.nlp.aws.entities.graph(a, {
69+
nodeProperty: "body",
70+
writeRelationshipType: "ENTITY"
71+
})
72+
YIELD graph AS g
73+
RETURN g;
74+
""", mapOf()) {
75+
assertGraphResult(it)
76+
}
77+
}
78+
79+
@Test
80+
fun `should extract key phrases in stream mode`() {
81+
neo4j.executeTransactionally("""
82+
MATCH (a:Article {uri: "https://neo4j.com/blog/pokegraph-gotta-graph-em-all/"})
83+
CALL apoc.nlp.aws.keyPhrases.stream(a, {
84+
nodeProperty: "body"
85+
})
86+
YIELD value
87+
UNWIND value.keyPhrases AS result
88+
RETURN result
89+
""", mapOf()) {
90+
assertStreamWithScoreResult(it)
91+
}
92+
}
93+
94+
@Test
95+
fun `should extract key phrases in graph mode`() {
96+
neo4j.executeTransactionally("""
97+
MATCH (a:Article {uri: "https://neo4j.com/blog/pokegraph-gotta-graph-em-all/"})
98+
CALL apoc.nlp.aws.keyPhrases.graph(a, {
99+
nodeProperty: "body",
100+
writeRelationshipType: "KEY_PHRASE",
101+
write: true
102+
})
103+
YIELD graph AS g
104+
RETURN g;
105+
""", mapOf()) {
106+
assertGraphResult(it)
107+
}
108+
}
109+
110+
@Test
111+
fun `should extract sentiment in stream mode`() {
112+
neo4j.executeTransactionally("""
113+
MATCH (a:Article {uri: "https://neo4j.com/blog/pokegraph-gotta-graph-em-all/"})
114+
CALL apoc.nlp.aws.sentiment.stream(a, {
115+
nodeProperty: "body"
116+
})
117+
YIELD value
118+
RETURN value AS result;
119+
""", mapOf()) {
120+
assertSentimentScoreResult(it)
121+
}
122+
}
123+
124+
@Test
125+
fun `should extract sentiment in graph mode`() {
126+
neo4j.executeTransactionally("""
127+
MATCH (a:Article {uri: "https://neo4j.com/blog/pokegraph-gotta-graph-em-all/"})
128+
CALL apoc.nlp.aws.sentiment.graph(a, {
129+
nodeProperty: "body",
130+
write: true
131+
})
132+
YIELD graph AS g
133+
UNWIND g.nodes AS node
134+
RETURN node {.uri, .sentiment, .sentimentScore} AS result;
135+
""", mapOf()) {
136+
assertSentimentScoreResult(it)
137+
}
138+
}
139+
140+
private fun assertStreamWithScoreResult(it: Result) {
141+
val asSequence = it.asSequence().toList()
142+
assertTrue(asSequence.isNotEmpty())
143+
144+
asSequence.forEach {
145+
val entity: Map<String, Any> = it["result"] as Map<String, Any>
146+
assertTrue(entity.containsKey("score"))
147+
}
148+
}
149+
150+
private fun assertGraphResult(it: Result) {
151+
val asSequence = it.asSequence().toList()
152+
assertTrue(asSequence.isNotEmpty())
153+
154+
asSequence.forEach {
155+
val entity: Map<String, Any> = it["g"] as Map<String, Any>
156+
assertTrue(entity.containsKey("nodes"))
157+
assertTrue(entity.containsKey("relationships"))
158+
}
159+
}
160+
161+
private fun assertSentimentScoreResult(it: Result) {
162+
val asSequence = it.asSequence().toList()
163+
assertTrue(asSequence.isNotEmpty())
164+
165+
asSequence.forEach {
166+
val entity: Map<String, Any> = it["result"] as Map<String, Any>
167+
assertTrue(entity.containsKey("sentimentScore"))
168+
}
169+
}
170+
}
171+

full/src/test/kotlin/apoc/nlp/aws/AWSProceduresErrorsTest.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
*/
1919
package apoc.nlp.aws
2020

21+
import apoc.nlp.aws.RealAWSClient.Companion.missingCredentialError
2122
import apoc.util.TestUtil
2223
import org.hamcrest.CoreMatchers.containsString
2324
import org.junit.AfterClass
@@ -94,7 +95,7 @@ class AWSProceduresErrorsTest {
9495
println(it.resultAsString())
9596
}
9697
}
97-
assertThat(exception.message, containsString("java.lang.IllegalArgumentException: Missing parameter `key`"))
98+
assertThat(exception.message, containsString(missingCredentialError))
9899
}
99100

100101
@Test
@@ -111,6 +112,6 @@ class AWSProceduresErrorsTest {
111112
println(it.resultAsString())
112113
}
113114
}
114-
assertThat(exception.message, containsString("java.lang.IllegalArgumentException: Missing parameter `secret`"))
115+
assertThat(exception.message, containsString(missingCredentialError))
115116
}
116117
}

0 commit comments

Comments
 (0)