Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/ArduinoIoTCloudTCP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ int ArduinoIoTCloudTCP::begin(bool const enable_watchdog, String brokerAddress,

#if OTA_ENABLED && !defined(OFFLOADED_DOWNLOAD)
_ota.setClient(&_otaClient);
if (_connection->getInterface() == NetworkAdapter::ETHERNET) {
_ota.setFetchMode(OTADefaultCloudProcessInterface::OtaFetchChunk);
}
#endif // OTA_ENABLED && !defined(OFFLOADED_DOWNLOAD)

#if OTA_ENABLED && defined(OTA_BASIC_AUTH)
Expand Down
132 changes: 89 additions & 43 deletions src/ota/interface/OTAInterfaceDefault.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ OTADefaultCloudProcessInterface::OTADefaultCloudProcessInterface(MessageStream *
, client(client)
, http_client(nullptr)
, username(nullptr), password(nullptr)
, fetchMode(OtaFetchTime)
, context(nullptr) {
}

Expand All @@ -41,57 +42,43 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::startOTA() {
}
);

// make the http get request
// check url
if(strcmp(context->parsed_url.schema(), "https") == 0) {
http_client = new HttpClient(*client, context->parsed_url.host(), context->parsed_url.port());
} else {
return UrlParseErrorFail;
}

http_client->beginRequest();
auto res = http_client->get(context->parsed_url.path());

if(username != nullptr && password != nullptr) {
http_client->sendBasicAuth(username, password);
}

http_client->endRequest();

if(res == HTTP_ERROR_CONNECTION_FAILED) {
DEBUG_VERBOSE("OTA ERROR: http client error connecting to server \"%s:%d\"",
context->parsed_url.host(), context->parsed_url.port());
return ServerConnectErrorFail;
} else if(res == HTTP_ERROR_TIMED_OUT) {
DEBUG_VERBOSE("OTA ERROR: http client timeout \"%s\"", OTACloudProcessInterface::context->url);
return OtaHeaderTimeoutFail;
} else if(res != HTTP_SUCCESS) {
DEBUG_VERBOSE("OTA ERROR: http client returned %d on get \"%s\"", res, OTACloudProcessInterface::context->url);
return OtaDownloadFail;
}

int statusCode = http_client->responseStatusCode();

if(statusCode != 200) {
DEBUG_VERBOSE("OTA ERROR: get response on \"%s\" returned status %d", OTACloudProcessInterface::context->url, statusCode);
return HttpResponseFail;
}
// make the http get request
requestOta(OtaFetchTime);

// The following call is required to save the header value , keep it
if(http_client->contentLength() == HttpClient::kNoContentLengthHeader) {
context->contentLength = http_client->contentLength();
if(context->contentLength == HttpClient::kNoContentLengthHeader) {
DEBUG_VERBOSE("OTA ERROR: the response header doesn't contain \"ContentLength\" field");
return HttpHeaderErrorFail;
}

context->lastReportTime = millis();

DEBUG_VERBOSE("OTA file length: %d", context->contentLength);
return Fetch;
}

OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() {
OTACloudProcessInterface::State res = Fetch;
int http_res = 0;
uint32_t start = millis();

if(fetchMode == OtaFetchChunk) {
res = requestOta(OtaFetchChunk);
}

context->downloadedChunkSize = 0;
context->downloadedChunkStartTime = millis();

if(res != Fetch) {
goto exit;
}

/* download chunked or timed */
do {
if(!http_client->connected()) {
res = OtaDownloadFail;
Expand All @@ -104,7 +91,7 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() {
continue;
}

http_res = http_client->read(context->buffer, context->buf_len);
int http_res = http_client->read(context->buffer, context->bufLen);

if(http_res < 0) {
DEBUG_VERBOSE("OTA ERROR: Download read error %d", http_res);
Expand All @@ -119,8 +106,10 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() {
res = ErrorWriteUpdateFileFail;
goto exit;
}
} while((context->downloadState == OtaDownloadFile || context->downloadState == OtaDownloadHeader) &&
millis() - start < downloadTime);

context->downloadedChunkSize += http_res;

} while(context->downloadState < OtaDownloadCompleted && fetchMore());

// TODO verify that the information present in the ota header match the info in context
if(context->downloadState == OtaDownloadCompleted) {
Expand Down Expand Up @@ -153,13 +142,69 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() {
return res;
}

void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t buf_len) {
OTACloudProcessInterface::State OTADefaultCloudProcessInterface::requestOta(OTAFetchMode mode) {
int http_res = 0;

/* stop connected client */
http_client->stop();

/* request chunk */
http_client->beginRequest();
http_res = http_client->get(context->parsed_url.path());

if(username != nullptr && password != nullptr) {
http_client->sendBasicAuth(username, password);
}

if(mode == OtaFetchChunk) {
char range[128] = {0};
size_t rangeSize = context->downloadedSize + maxChunkSize > context->contentLength ? context->contentLength - context->downloadedSize : maxChunkSize;
sprintf(range, "bytes=%d-%d", context->downloadedSize, context->downloadedSize + rangeSize);
DEBUG_VERBOSE("OTA downloading range: %s", range);
http_client->sendHeader("Range", range);
}

http_client->endRequest();

if(http_res == HTTP_ERROR_CONNECTION_FAILED) {
DEBUG_VERBOSE("OTA ERROR: http client error connecting to server \"%s:%d\"",
context->parsed_url.host(), context->parsed_url.port());
return ServerConnectErrorFail;
} else if(http_res == HTTP_ERROR_TIMED_OUT) {
DEBUG_VERBOSE("OTA ERROR: http client timeout \"%s\"", OTACloudProcessInterface::context->url);
return OtaHeaderTimeoutFail;
} else if(http_res != HTTP_SUCCESS) {
DEBUG_VERBOSE("OTA ERROR: http client returned %d on get \"%s\"", http_res, OTACloudProcessInterface::context->url);
return OtaDownloadFail;
}

int statusCode = http_client->responseStatusCode();

if(((mode == OtaFetchChunk) && (statusCode != 206)) || ((mode == OtaFetchTime) && (statusCode != 200))) {
DEBUG_VERBOSE("OTA ERROR: get response on \"%s\" returned status %d", OTACloudProcessInterface::context->url, statusCode);
return HttpResponseFail;
}

http_client->skipResponseHeaders();

return Fetch;
}

bool OTADefaultCloudProcessInterface::fetchMore() {
if (fetchMode == OtaFetchChunk) {
return context->downloadedChunkSize < maxChunkSize;
} else {
return (millis() - context->downloadedChunkStartTime) < downloadTime;
}
}

void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t bufLen) {
assert(context != nullptr); // This should never fail

for(uint8_t* cursor=(uint8_t*)buffer; cursor<buffer+buf_len; ) {
for(uint8_t* cursor=(uint8_t*)buffer; cursor<buffer+bufLen; ) {
switch(context->downloadState) {
case OtaDownloadHeader: {
const uint32_t headerLeft = context->headerCopiedBytes + buf_len <= sizeof(context->header.buf) ? buf_len : sizeof(context->header.buf) - context->headerCopiedBytes;
const uint32_t headerLeft = context->headerCopiedBytes + bufLen <= sizeof(context->header.buf) ? bufLen : sizeof(context->header.buf) - context->headerCopiedBytes;
memcpy(context->header.buf+context->headerCopiedBytes, buffer, headerLeft);
cursor += headerLeft;
context->headerCopiedBytes += headerLeft;
Expand All @@ -184,8 +229,7 @@ void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t buf_len)
break;
}
case OtaDownloadFile: {
const uint32_t contentLength = http_client->contentLength();
const uint32_t dataLeft = buf_len - (cursor-buffer);
const uint32_t dataLeft = bufLen - (cursor-buffer);
context->decoder.decompress(cursor, dataLeft); // TODO verify return value

context->calculatedCrc32 = crc_update(
Expand All @@ -198,18 +242,18 @@ void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t buf_len)
context->downloadedSize += dataLeft;

if((millis() - context->lastReportTime) > 10000) { // Report the download progress each X millisecond
DEBUG_VERBOSE("OTA Download Progress %d/%d", context->downloadedSize, contentLength);
DEBUG_VERBOSE("OTA Download Progress %d/%d", context->downloadedSize, context->contentLength);

reportStatus(context->downloadedSize);
context->lastReportTime = millis();
}

// TODO there should be no more bytes available when the download is completed
if(context->downloadedSize == contentLength) {
if(context->downloadedSize == context->contentLength) {
context->downloadState = OtaDownloadCompleted;
}

if(context->downloadedSize > contentLength) {
if(context->downloadedSize > context->contentLength) {
context->downloadState = OtaDownloadError;
}
// TODO fail if we exceed a timeout? and available is 0 (client is broken)
Expand Down Expand Up @@ -250,7 +294,9 @@ OTADefaultCloudProcessInterface::Context::Context(
, headerCopiedBytes(0)
, downloadedSize(0)
, lastReportTime(0)
, contentLength(0)
, writeError(false)
, downloadedChunkSize(0)
, decoder(putc) { }

static const uint32_t crc_table[256] = {
Expand Down
20 changes: 18 additions & 2 deletions src/ota/interface/OTAInterfaceDefault.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,36 @@ class OTADefaultCloudProcessInterface: public OTACloudProcessInterface {
this->password = password;
}

enum OTAFetchMode: uint8_t {
OtaFetchTime,
OtaFetchChunk
};

inline virtual void setFetchMode(OTAFetchMode mode) { this->fetchMode = mode; }

protected:
State startOTA();
State fetch();
void reset();
virtual int writeFlash(uint8_t* const buffer, size_t len) = 0;

private:
void parseOta(uint8_t* buffer, size_t buf_len);
void parseOta(uint8_t* buffer, size_t bufLen);
State requestOta(OTAFetchMode mode);
bool fetchMore();

Client* client;
HttpClient* http_client;

const char *username, *password;
OTAFetchMode fetchMode;

// The amount of time that each iteration of Fetch has to take at least
// This mitigate the issues arising from tasks run in main loop that are using all the computing time
static constexpr uint32_t downloadTime = 2000;

static constexpr size_t maxChunkSize = 1024 * 10;

enum OTADownloadState: uint8_t {
OtaDownloadHeader,
OtaDownloadFile,
Expand All @@ -74,12 +86,16 @@ class OTADefaultCloudProcessInterface: public OTACloudProcessInterface {
uint32_t headerCopiedBytes;
uint32_t downloadedSize;
uint32_t lastReportTime;
uint32_t contentLength;
bool writeError;

uint32_t downloadedChunkStartTime;
uint32_t downloadedChunkSize;

// LZSS decoder
LZSSDecoder decoder;

const size_t buf_len = 64;
const size_t bufLen = 64;
uint8_t buffer[64];
} *context;
};
Expand Down
Loading