Skip to content

Commit 783d1f8

Browse files
authored
Support callback header (#1808)
* Support runscript callbackHeader
1 parent b80cb25 commit 783d1f8

File tree

3 files changed

+58
-12
lines changed

3 files changed

+58
-12
lines changed

streamingpro-core/src/main/java/tech/mlsql/crawler/RestUtils.scala

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ import net.csdn.common.path.Url
44
import net.csdn.modules.transport.HttpTransportService.SResponse
55
import net.csdn.modules.transport.{DefaultHttpTransportService, HttpTransportService}
66
import org.apache.commons.lang3.exception.ExceptionUtils
7-
import org.apache.http.{HttpEntity, HttpResponse}
87
import org.apache.http.client.entity.UrlEncodedFormEntity
98
import org.apache.http.client.fluent.{Form, Request}
109
import org.apache.http.entity.ContentType
1110
import org.apache.http.entity.mime.{HttpMultipartMode, MultipartEntityBuilder}
1211
import org.apache.http.message.BasicNameValuePair
1312
import org.apache.http.util.EntityUtils
13+
import org.apache.http.{HttpEntity, HttpResponse}
1414
import streaming.dsl.ScriptSQLExec
1515
import streaming.log.WowLog
1616
import tech.mlsql.common.JsonUtils
@@ -22,17 +22,21 @@ import tech.mlsql.tool.{HDFSOperatorV2, Templates2}
2222
import java.nio.charset.Charset
2323
import scala.annotation.tailrec
2424
import scala.collection.JavaConversions._
25-
import scala.util.control.Breaks.{break, breakable}
2625

2726
object RestUtils extends Logging with WowLog {
28-
def httpClientPost(urlString: String, data: Map[String, String]): HttpResponse = {
27+
def httpClientPost(urlString: String, data: Map[String, String], headers: Map[String, String]): HttpResponse = {
2928
val nameValuePairs = data.map { case (name, value) =>
3029
new BasicNameValuePair(name, value)
3130
}.toList
3231

33-
Request.Post(urlString)
32+
val req = Request.Post(urlString)
3433
.addHeader("Content-Type", "application/x-www-form-urlencoded")
35-
.body(new UrlEncodedFormEntity(nameValuePairs, DefaultHttpTransportService.charset))
34+
35+
headers foreach { case (name, value) =>
36+
req.setHeader(name, value)
37+
}
38+
39+
req.body(new UrlEncodedFormEntity(nameValuePairs, DefaultHttpTransportService.charset))
3640
.execute()
3741
.returnResponse()
3842
}

streamingpro-it/src/test/scala/tech/mlsql/it/ByzerScriptTestSuite.scala

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
package tech.mlsql.it
22

3+
import net.csdn.modules.transport.DefaultHttpTransportService
4+
import org.apache.http.HttpEntity
5+
import org.apache.http.util.EntityUtils
36
import tech.mlsql.common.utils.log.Logging
47
import tech.mlsql.crawler.RestUtils
58
import tech.mlsql.it.contiainer.ByzerCluster
@@ -8,6 +11,7 @@ import tech.mlsql.it.utils.DockerUtils.getCurProjectRootPath
811

912
import java.io.File
1013
import java.util.UUID
14+
import scala.collection.mutable
1115

1216
/**
1317
* 23/02/2022 hellozepp(lisheng.zhanglin@163.com)
@@ -54,11 +58,27 @@ class ByzerScriptTestSuite extends LocalBaseTestSuite with Logging {
5458
})
5559
}
5660

57-
def runScript(url: String, user: String, code: String): (Int, String) = {
61+
def runScript(url: String, user: String, code: String, callbackHeader: String = ""): (Int, String) = {
5862
val jobName = UUID.randomUUID().toString
63+
val params = mutable.Map("sql" -> code, "owner" -> user,
64+
"jobName" -> jobName, "sessionPerUser" -> "true", "sessionPerRequest" -> "true")
65+
if (callbackHeader != "") params.put("callbackHeader", callbackHeader)
5966
logInfo(s"The test submits a script to the container through Rest, url:$url, sql:$code")
60-
val (status, result) = RestUtils.rest_request_string(url, "post", Map("sql" -> code, "owner" -> user,
61-
"jobName" -> jobName, "sessionPerUser" -> "true", "sessionPerRequest" -> "true"),
67+
val (status, result) = RestUtils.rest_request_string(url, "post", params.toMap,
68+
Map("Content-Type" -> "application/x-www-form-urlencoded"), Map("socket-timeout" -> "1800s",
69+
"connect-timeout" -> "1800s", "retry" -> "1")
70+
)
71+
logInfo(s"status:$status,result:$result")
72+
(status, result)
73+
}
74+
75+
def runScriptWithHeader(url: String, user: String, code: String, callbackHeader: String = ""): (Int, HttpEntity) = {
76+
val jobName = UUID.randomUUID().toString
77+
val params = mutable.Map("sql" -> code, "owner" -> user,
78+
"jobName" -> jobName, "sessionPerUser" -> "true", "sessionPerRequest" -> "true")
79+
if (callbackHeader != "") params.put("callbackHeader", callbackHeader)
80+
logInfo(s"The test submits a script to the container through Rest, url:$url, sql:$code")
81+
val (status, result) = RestUtils.rest_request(url, "post", params.toMap,
6282
Map("Content-Type" -> "application/x-www-form-urlencoded"), Map("socket-timeout" -> "1800s",
6383
"connect-timeout" -> "1800s", "retry" -> "1")
6484
)
@@ -82,7 +102,7 @@ class ByzerScriptTestSuite extends LocalBaseTestSuite with Logging {
82102
val cluster: ByzerCluster = setupCluster()
83103
val hadoopContainer = cluster.hadoopContainer
84104
val byzerLangContainer = cluster.byzerLangContainer
85-
val javaContainer = cluster.byzerLangContainer.container
105+
val javaContainer = byzerLangContainer.container
86106
url = s"http://${javaContainer.getHost}:${javaContainer.getMappedPort(9003)}/run/script"
87107

88108
test("javaContainer") {
@@ -101,6 +121,19 @@ class ByzerScriptTestSuite extends LocalBaseTestSuite with Logging {
101121
}
102122

103123
test("Execute yarn sql file") {
124+
try {
125+
val (_, result) = runScriptWithHeader(url, user, "select 1 as a,'jack' as b as bbc;",
126+
"""{"Authorization":"Bearer acc"}""")
127+
val _result = EntityUtils.toString(result, DefaultHttpTransportService.charset)
128+
println("With callbackHeader result:" + _result)
129+
assert(_result === "[{\"a\":1,\"b\":\"jack\"}]")
130+
} catch {
131+
case _: Exception =>
132+
val res = "callbackHeader should be returned normally in the byzer callback!"
133+
logError(res)
134+
throw new RuntimeException(res)
135+
}
136+
104137
TestManager.testCases.foreach(testCase => {
105138
try {
106139
val (status, result) = runScript(url, user, testCase.sql)
@@ -110,7 +143,6 @@ class ByzerScriptTestSuite extends LocalBaseTestSuite with Logging {
110143
TestManager.acceptRest(testCase, 500, null, e)
111144
}
112145
})
113-
114146
TestManager.report()
115147
}
116148

streamingpro-mlsql/src/main/java/streaming/rest/RestController.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import org.apache.spark.sql.mlsql.session.{MLSQLSparkSession, SparkSessionCacheM
3838
import org.apache.spark.{MLSQLConf, SparkInstanceService}
3939
import tech.mlsql.MLSQLEnvKey
4040
import tech.mlsql.app.{CustomController, ResultResp}
41+
import tech.mlsql.common.JsonUtils
4142
import tech.mlsql.common.utils.log.Logging
4243
import tech.mlsql.common.utils.serder.json.JSONTool
4344
import tech.mlsql.crawler.RestUtils
@@ -105,6 +106,7 @@ class RestController extends ApplicationController with WowLog with Logging {
105106
new Parameter(name = "sessionPerRequest", required = false, description = "by default false", `type` = "boolean", allowEmptyValue = false),
106107
new Parameter(name = "async", required = false, description = "If set true ,please also provide a callback url use `callback` parameter and the job will run in background and the API will return. default: false", `type` = "boolean", allowEmptyValue = false),
107108
new Parameter(name = "callback", required = false, description = "Used when async is set true. callback is a url. default: false", `type` = "string", allowEmptyValue = false),
109+
new Parameter(name = "callbackHeader", required = false, description = "Provide a jsonString parameter to set the header parameter of the callback request. default: false", `type` = "string", allowEmptyValue = false),
108110
new Parameter(name = "maxRetries", required = false, description = "Max retries of request callback.", `type` = "int", allowEmptyValue = false),
109111
new Parameter(name = "skipInclude", required = false, description = "disable include statement. default: false", `type` = "boolean", allowEmptyValue = false),
110112
new Parameter(name = "skipAuth", required = false, description = "disable table authorize . default: true", `type` = "boolean", allowEmptyValue = false),
@@ -147,6 +149,12 @@ class RestController extends ApplicationController with WowLog with Logging {
147149
if (paramAsBoolean("async", false)) {
148150
JobManager.asyncRun(sparkSession, jobInfo, () => {
149151
val urlString = param("callback")
152+
val callbackHeaderString = param("callbackHeader")
153+
var callbackHeader = Map[String,String]()
154+
if (callbackHeaderString != null && callbackHeaderString.nonEmpty){
155+
callbackHeader = JsonUtils.fromJson[Map[String,String]](callbackHeaderString)
156+
}
157+
150158
val maxTries = Math.max(0, paramAsInt("maxRetries", -1)) + 1
151159
try {
152160
ScriptSQLExec.parse(param("sql"), context,
@@ -161,7 +169,8 @@ class RestController extends ApplicationController with WowLog with Logging {
161169
RestUtils.httpClientPost(urlString,
162170
Map("stat" -> s"""succeeded""",
163171
"res" -> outputResult,
164-
"jobInfo" -> JSONTool.toJsonStr(jobInfo))),
172+
"jobInfo" -> JSONTool.toJsonStr(jobInfo)),
173+
callbackHeader),
165174
HttpStatus.SC_OK == _.getStatusLine.getStatusCode,
166175
response => logger.error(s"Succeeded SQL callback request failed after ${maxTries} attempts, " +
167176
s"the last response status is: ${response.getStatusLine.getStatusCode}.")
@@ -178,7 +187,8 @@ class RestController extends ApplicationController with WowLog with Logging {
178187
RestUtils.httpClientPost(urlString,
179188
Map("stat" -> s"""failed""",
180189
"msg" -> (e.getMessage + "\n" + msgBuffer.mkString("\n")),
181-
"jobInfo" -> JSONTool.toJsonStr(jobInfo))),
190+
"jobInfo" -> JSONTool.toJsonStr(jobInfo)),
191+
callbackHeader),
182192
HttpStatus.SC_OK == _.getStatusLine.getStatusCode,
183193
response => logger.error(s"Fail SQL callback request failed after ${maxTries} attempts, " +
184194
s"the last response status is: ${response.getStatusLine.getStatusCode}.")

0 commit comments

Comments
 (0)