Skip to content

Commit a7f55d1

Browse files
authored
add additional info (#1845)
1 parent 69d7971 commit a7f55d1

File tree

1 file changed

+38
-26
lines changed

1 file changed

+38
-26
lines changed

streamingpro-mlsql/src/main/java/tech/mlsql/datasource/impl/MLSQLRest.scala

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -75,22 +75,24 @@ class MLSQLRest(override val uid: String) extends MLSQLSource
7575
val skipParams = config.config.getOrElse("config.page.skip-params", "false").toBoolean
7676
val retryInterval = JavaUtils.timeStringAsMs(config.config.getOrElse("config.retry.interval", "1s"))
7777
val debug = config.config.getOrElse("config.debug", "false").toBoolean
78+
val strategy = config.config.getOrElse("config.page.error.strategy", "abort")
7879
val enableRequestCleaner = config.config.getOrElse(configEnableRequestCleaner.name, "false")
7980
ScriptSQLExec.context().execListener.addEnv("enableRestDataSourceRequestCleaner", enableRequestCleaner)
8081

81-
def _error2DataFrame(errMsg: String, statusCode: Int, session: SparkSession) : DataFrame = {
82-
session.createDataFrame(session.sparkContext.makeRDD(Seq(Row.fromSeq(Seq( errMsg.getBytes(UTF_8), statusCode))))
82+
def _error2DataFrame(errMsg: String, statusCode: Int, session: SparkSession): DataFrame = {
83+
session.createDataFrame(session.sparkContext.makeRDD(Seq(Row.fromSeq(Seq(errMsg.getBytes(UTF_8), statusCode))))
8384
, StructType(fields = Seq(
8485
StructField("content", BinaryType), StructField("status", IntegerType)
8586
)))
8687
}
8788

8889
/**
8990
* Calling http rest endpoints with retrying
90-
* @param url http url
91-
* @param skipParams if true, not adding parameters to http get urls
92-
* @maxTries The max number of attempts
93-
* @return http status code along with DataFrame
91+
*
92+
* @param url http url
93+
* @param skipParams if true, not adding parameters to http get urls
94+
* @param maxTries The max number of attempts
95+
* @return http status code along with DataFrame
9496
*/
9597
def _httpWithRetrying(url: String, skipParams: Boolean, maxTries: Int): (Int, DataFrame) = {
9698
executeWithRetrying[(Int, DataFrame)](maxTries)((() => {
@@ -105,18 +107,18 @@ class MLSQLRest(override val uid: String) extends MLSQLSource
105107
} catch {
106108
// According to _http function, it throws MLSQLException if any request parameter is invalid
107109
case me: MLSQLException =>
108-
if( me.getMessage.startsWith("content-type"))
110+
if (me.getMessage.startsWith("content-type"))
109111
(415, _error2DataFrame(me.getMessage, 415, config.df.get.sparkSession))
110-
else if(me.getMessage.startsWith("HTTP method"))
111-
(405, _error2DataFrame(me.getMessage,405, config.df.get.sparkSession))
112+
else if (me.getMessage.startsWith("HTTP method"))
113+
(405, _error2DataFrame(me.getMessage, 405, config.df.get.sparkSession))
112114
else
113115
(500, _error2DataFrame(me.getMessage, 500, config.df.get.sparkSession))
114116
case e: Exception => (0, _error2DataFrame(e.getMessage, 0, config.df.get.sparkSession))
115117
}
116118
}) (),
117119
tempResp => {
118120
val succeed = tempResp._1 == 200
119-
if (! succeed) {
121+
if (!succeed) {
120122
Thread.sleep(retryInterval)
121123
}
122124
succeed
@@ -147,8 +149,8 @@ class MLSQLRest(override val uid: String) extends MLSQLSource
147149

148150
// If a user runs multiple Load Rest.`` statements in a single thread, we need to save all temp dirs
149151
context.execListener.env().get(classOf[MLSQLRest].getName) match {
150-
case Some(dirs) => context.execListener.addEnv( classOf[MLSQLRest].getName, s"${dirs},${tmpTablePath}" )
151-
case None => context.execListener.addEnv( classOf[MLSQLRest].getName, tmpTablePath )
152+
case Some(dirs) => context.execListener.addEnv(classOf[MLSQLRest].getName, s"${dirs},${tmpTablePath}")
153+
case None => context.execListener.addEnv(classOf[MLSQLRest].getName, tmpTablePath)
152154
}
153155

154156
var count = 0
@@ -157,28 +159,37 @@ class MLSQLRest(override val uid: String) extends MLSQLSource
157159
var url = config.path
158160
do {
159161
val pageFetchTime = System.currentTimeMillis()
160-
val _skipParams = if( count == 0 ) false else skipParams
162+
val _skipParams = if (count == 0) false else skipParams
161163
val (_, dataFrame) = _httpWithRetrying(url, _skipParams, maxTries)
162164
val row = dataFrame.select(F.col("content").cast(StringType), F.col("status")).head
163165
val content = row.getString(0)
164166
// Reset status
165167
status = row.getInt(1)
166168
pageStrategy.nexPage(Option(content))
167-
url = pageStrategy.pageUrl( Option(content) )
169+
url = pageStrategy.pageUrl(Option(content))
168170
hasNextPage = pageStrategy.hasNextPage(Option(content))
171+
val exceptionMsg = s"URL:${url},with response status ${status}, Have retried ${maxTries} times!\n ${content}"
172+
if (status != 200) {
173+
// If the strategy is abort, raise the error; if the strategy is skip, just log the exception as a warning
174+
strategy match {
175+
case "abort" => throw new RuntimeException(exceptionMsg)
176+
case "skip" =>
177+
// First page request failed, return immediately, otherwise return saved DataFrame
178+
logWarning(exceptionMsg)
179+
if (count == 0) return dataFrame else return context.execListener.sparkSession.read.parquet(tmpTablePath)
180+
}
169181

170-
if( status != 200 ) {
171-
// First page request failed, return immediately, otherwise return saved DataFrame
172-
if( count == 0 ) return dataFrame else return context.execListener.sparkSession.read.parquet(tmpTablePath)
173182
}
174183
dataFrame.write.format("parquet").mode(SaveMode.Append).save(tmpTablePath)
184+
logInfo(s"Data from url:${url} from start:${count * maxSize} to end:${(count + 1) * maxSize - 1} " +
185+
s"is done! and dump to ${tmpTablePath}")
175186
if (debug) {
176187
logInfo(format(s"Getting Page ${count} ${url} Consume:${System.currentTimeMillis() - pageFetchTime}ms"))
177188
}
178-
if( count > 0 ) Thread.sleep( pageInterval)
189+
if (count > 0) Thread.sleep(pageInterval)
179190
count += 1
180191
}
181-
while( count < maxSize && hasNextPage && status == 200 )
192+
while (count < maxSize && hasNextPage && status == 200)
182193
context.execListener.sparkSession.read.parquet(tmpTablePath)
183194

184195
case (None, None) =>
@@ -192,11 +203,12 @@ class MLSQLRest(override val uid: String) extends MLSQLSource
192203

193204
/**
194205
* Send http request and return the result as a DataFrame
195-
* @param url http url
196-
* @param params http parameters, including content-type http method and others
197-
* @param skipParams if true, not adding parameters to http get url
198-
* @param session The Spark Session
199-
* @return DataFrame , is guaranteed to be not null, with 2 columns content: Array[String] status: Int
206+
*
207+
* @param url http url
208+
* @param params http parameters, including content-type http method and others
209+
* @param skipParams if true, not adding parameters to http get url
210+
* @param session The Spark Session
211+
*         @return DataFrame, guaranteed to be non-null, with 2 columns — content: Array[Byte] (binary) and status: Int
200212
*/
201213
private def _http(url: String, params: Map[String, String], skipParams: Boolean, session: SparkSession): DataFrame = {
202214
val httpMethod = params.getOrElse(configMethod.name, "get").toLowerCase
@@ -354,11 +366,11 @@ class MLSQLRest(override val uid: String) extends MLSQLSource
354366

355367
val filePathBuf = ArrayBuffer[(String, String)]()
356368
HDFSOperatorV2.isFile(finalPath) match {
357-
case true =>
369+
case true =>
358370
filePathBuf.append((finalPath, fileName))
359371
case false if HDFSOperatorV2.isDir(finalPath) =>
360372
val listFiles = HDFSOperatorV2.listFiles(finalPath)
361-
if(listFiles.filter(_.isDirectory).size > 0) throw new MLSQLException(s"Including subdirectories is not supported")
373+
if (listFiles.filter(_.isDirectory).size > 0) throw new MLSQLException(s"Including subdirectories is not supported")
362374

363375
listFiles.filterNot(_.getPath.getName.equals("_SUCCESS"))
364376
.foreach(file => filePathBuf.append((file.getPath.toString, file.getPath.getName)))

0 commit comments

Comments
 (0)